Update Ollama to use streaming API (#138177)

* Update ollama to use streaming APIs

* Remove unnecessary logging

* Update ollama to use streaming APIs

* Remove unnecessary logging

* Update homeassistant/components/ollama/conversation.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
Allen Porter 2025-02-09 21:05:41 -08:00 committed by GitHub
parent ae38f89728
commit 15223b3679
2 changed files with 167 additions and 74 deletions

homeassistant/components/ollama/conversation.py

@@ -2,7 +2,7 @@
 from __future__ import annotations

-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 import json
 import logging
 from typing import Any, Literal
@@ -123,7 +123,47 @@ def _convert_content(
             role=MessageRole.SYSTEM.value,
             content=chat_content.content,
         )
-    raise ValueError(f"Unexpected content type: {type(chat_content)}")
+    raise TypeError(f"Unexpected content type: {type(chat_content)}")
+
+
+async def _transform_stream(
+    result: AsyncGenerator[ollama.Message],
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    """Transform the response stream into HA format.
+
+    An Ollama streaming response may come in chunks like this:
+
+    response: message=Message(role="assistant", content="Paris")
+    response: message=Message(role="assistant", content=".")
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+    response: message=Message(role="assistant", tool_calls=[...])
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+
+    This generator conforms to the chatlog delta stream expectations in that it
+    yields deltas, then the role only once the response is done.
+    """
+    new_msg = True
+    async for response in result:
+        _LOGGER.debug("Received response: %s", response)
+        response_message = response["message"]
+        chunk: conversation.AssistantContentDeltaDict = {}
+        if new_msg:
+            new_msg = False
+            chunk["role"] = "assistant"
+        if (tool_calls := response_message.get("tool_calls")) is not None:
+            chunk["tool_calls"] = [
+                llm.ToolInput(
+                    tool_name=tool_call["function"]["name"],
+                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
+                )
+                for tool_call in tool_calls
+            ]
+        if (content := response_message.get("content")) is not None:
+            chunk["content"] = content
+        if response_message.get("done"):
+            new_msg = True
+        yield chunk


 class OllamaConversationEntity(
@@ -216,12 +256,12 @@ class OllamaConversationEntity(
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
             try:
-                response = await client.chat(
+                response_generator = await client.chat(
                     model=model,
                     # Make a copy of the messages because we mutate the list later
                     messages=list(message_history.messages),
                     tools=tools,
-                    stream=False,
+                    stream=True,
                     # keep_alive requires specifying unit. In this case, seconds
                     keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
                     options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
@@ -232,46 +272,25 @@ class OllamaConversationEntity(
                     f"Sorry, I had a problem talking to the Ollama server: {err}"
                 ) from err

-            response_message = response["message"]
-            content = response_message.get("content")
-            tool_calls = response_message.get("tool_calls")
-
-            message_history.messages.append(
-                ollama.Message(
-                    role=response_message["role"],
-                    content=content,
-                    tool_calls=tool_calls,
-                )
-            )
-
-            tool_inputs = [
-                llm.ToolInput(
-                    tool_name=tool_call["function"]["name"],
-                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
-                )
-                for tool_call in tool_calls or ()
-            ]
-
             message_history.messages.extend(
                 [
-                    ollama.Message(
-                        role=MessageRole.TOOL.value,
-                        content=json.dumps(tool_response.tool_result),
-                    )
-                    async for tool_response in chat_log.async_add_assistant_content(
-                        conversation.AssistantContent(
-                            agent_id=user_input.agent_id,
-                            content=content,
-                            tool_calls=tool_inputs or None,
-                        )
+                    _convert_content(content)
+                    async for content in chat_log.async_add_delta_content_stream(
+                        user_input.agent_id, _transform_stream(response_generator)
                     )
                 ]
             )

-            if not tool_calls:
+            if not chat_log.unresponded_tool_results:
                 break

         # Create intent response
         intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_speech(response_message["content"])
+        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
+            raise TypeError(
+                f"Unexpected last message type: {type(chat_log.content[-1])}"
+            )
+        intent_response.async_set_speech(chat_log.content[-1].content or "")
         return conversation.ConversationResult(
             response=intent_response, conversation_id=chat_log.conversation_id
         )
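The core of the change is the new _transform_stream helper: it folds Ollama's chunked responses into the delta dicts the chat log expects, emitting the assistant role once at the start of each message, passing content and tool calls through as they arrive, and resetting on a done chunk so the next message starts a new delta. Below is a minimal, standalone sketch of that folding logic, not the integration code: plain dicts stand in for Home Assistant's AssistantContentDeltaDict and llm.ToolInput, the fake stream is illustrative, and the done flag is read from the top-level response for simplicity.

# Standalone sketch of the delta-folding idea behind _transform_stream (simplified).
import asyncio
from collections.abc import AsyncGenerator


async def fake_ollama_stream() -> AsyncGenerator[dict, None]:
    # Chunks shaped like the streaming responses described in the docstring above.
    yield {"message": {"role": "assistant", "content": "Paris"}}
    yield {"message": {"role": "assistant", "content": "."}}
    yield {"message": {"role": "assistant", "content": ""}, "done": True, "done_reason": "stop"}


async def transform(result: AsyncGenerator[dict, None]) -> AsyncGenerator[dict, None]:
    # Emit the role once per assistant message, then plain content deltas.
    new_msg = True
    async for response in result:
        message = response["message"]
        chunk: dict = {}
        if new_msg:
            new_msg = False
            chunk["role"] = "assistant"
        if (content := message.get("content")) is not None:
            chunk["content"] = content
        if response.get("done"):
            new_msg = True  # the next chunk begins a new assistant message
        yield chunk


async def main() -> None:
    async for delta in transform(fake_ollama_stream()):
        print(delta)
    # {'role': 'assistant', 'content': 'Paris'}
    # {'content': '.'}
    # {'content': ''}


asyncio.run(main())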

tests/components/ollama/test_conversation.py

@@ -1,5 +1,6 @@
 """Tests for the Ollama integration."""

+from collections.abc import AsyncGenerator
 from typing import Any
 from unittest.mock import AsyncMock, Mock, patch
@@ -25,6 +26,14 @@ def mock_ulid_tools():
         yield


+async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]:
+    """Generate a response from the assistant."""
+    if not isinstance(response, list):
+        response = [response]
+    for msg in response:
+        yield msg
+
+
 @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
 async def test_chat(
     hass: HomeAssistant,
@@ -42,7 +51,9 @@ async def test_chat(
     with patch(
         "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
+        return_value=stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        ),
     ) as mock_chat:
         result = await conversation.async_converse(
             hass,
@@ -81,6 +92,53 @@ async def test_chat(
     assert "Current time is" in detail_event["data"]["messages"][0]["content"]


+async def test_chat_stream(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test chat messages are assembled across streamed responses."""
+    entry = MockConfigEntry()
+    entry.add_to_hass(hass)
+
+    with patch(
+        "ollama.AsyncClient.chat",
+        return_value=stream_generator(
+            [
+                {"message": {"role": "assistant", "content": "test "}},
+                {
+                    "message": {"role": "assistant", "content": "response"},
+                    "done": True,
+                    "done_reason": "stop",
+                },
+            ],
+        ),
+    ) as mock_chat:
+        result = await conversation.async_converse(
+            hass,
+            "test message",
+            None,
+            Context(),
+            agent_id=mock_config_entry.entry_id,
+        )
+
+    assert mock_chat.call_count == 1
+    args = mock_chat.call_args.kwargs
+    prompt = args["messages"][0]["content"]
+
+    assert args["model"] == "test model"
+    assert args["messages"] == [
+        Message(role="system", content=prompt),
+        Message(role="user", content="test message"),
+    ]
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
+        result
+    )
+    assert result.response.speech["plain"]["speech"] == "test response"
+
+
 async def test_template_variables(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry
 ) -> None:
@@ -103,7 +161,9 @@ async def test_template_variables(
         patch("ollama.AsyncClient.list"),
         patch(
             "ollama.AsyncClient.chat",
-            return_value={"message": {"role": "assistant", "content": "test response"}},
+            return_value=stream_generator(
+                {"message": {"role": "assistant", "content": "test response"}}
+            ),
         ) as mock_chat,
         patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
     ):
@@ -170,26 +230,30 @@ async def test_function_call(
     def completion_result(*args, messages, **kwargs):
         for message in messages:
             if message["role"] == "tool":
-                return {
-                    "message": {
-                        "role": "assistant",
-                        "content": "I have successfully called the function",
-                    }
-                }
+                return stream_generator(
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": "I have successfully called the function",
+                        }
+                    }
+                )

-        return {
-            "message": {
-                "role": "assistant",
-                "tool_calls": [
-                    {
-                        "function": {
-                            "name": "test_tool",
-                            "arguments": tool_args,
-                        }
-                    }
-                ],
-            }
-        }
+        return stream_generator(
+            {
+                "message": {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "test_tool",
+                                "arguments": tool_args,
+                            }
+                        }
+                    ],
+                }
+            }
+        )

     with patch(
         "ollama.AsyncClient.chat",
@@ -251,26 +315,30 @@ async def test_function_exception(
     def completion_result(*args, messages, **kwargs):
         for message in messages:
             if message["role"] == "tool":
-                return {
-                    "message": {
-                        "role": "assistant",
-                        "content": "There was an error calling the function",
-                    }
-                }
+                return stream_generator(
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": "There was an error calling the function",
+                        }
+                    }
+                )

-        return {
-            "message": {
-                "role": "assistant",
-                "tool_calls": [
-                    {
-                        "function": {
-                            "name": "test_tool",
-                            "arguments": {"param1": "test_value"},
-                        }
-                    }
-                ],
-            }
-        }
+        return stream_generator(
+            {
+                "message": {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "test_tool",
+                                "arguments": {"param1": "test_value"},
+                            }
+                        }
+                    ],
+                }
+            }
+        )

     with patch(
         "ollama.AsyncClient.chat",
@@ -344,7 +412,9 @@ async def test_message_history_trimming(
     def response(*args, **kwargs) -> dict:
         nonlocal response_idx
         response_idx += 1
-        return {"message": {"role": "assistant", "content": f"response {response_idx}"}}
+        return stream_generator(
+            {"message": {"role": "assistant", "content": f"response {response_idx}"}}
+        )

     with patch(
         "ollama.AsyncClient.chat",
@@ -438,11 +508,13 @@ async def test_message_history_unlimited(
     """Test that message history is not trimmed when max_history = 0."""
     conversation_id = "1234"

+    def stream(*args, **kwargs) -> AsyncGenerator[dict]:
+        return stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        )
+
     with (
-        patch(
-            "ollama.AsyncClient.chat",
-            return_value={"message": {"role": "assistant", "content": "test response"}},
-        ) as mock_chat,
+        patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat,
     ):
         hass.config_entries.async_update_entry(
             mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
@@ -559,7 +631,9 @@ async def test_options(
     """Test that options are passed correctly to ollama client."""
     with patch(
         "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
+        return_value=stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        ),
     ) as mock_chat:
         await conversation.async_converse(
             hass,
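The test changes all follow one pattern: the patched ollama.AsyncClient.chat no longer returns a plain response dict but an async generator built by the new stream_generator helper (via return_value, or from a side_effect function), so the integration's `async for` over the streamed response still works. Below is a minimal, self-contained sketch of that mocking pattern; FakeClient and the other names are illustrative stand-ins, not the Home Assistant test fixtures.

# Self-contained sketch of the streaming-mock pattern used throughout these tests.
import asyncio
from collections.abc import AsyncGenerator
from unittest.mock import patch


class FakeClient:
    # Stand-in for ollama.AsyncClient; only the patched method matters here.
    async def chat(self, **kwargs) -> dict:
        raise NotImplementedError


async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict, None]:
    # Same shape as the test helper: yield one or more canned chunks.
    if not isinstance(response, list):
        response = [response]
    for msg in response:
        yield msg


async def main() -> None:
    chunks = [
        {"message": {"role": "assistant", "content": "test "}},
        {"message": {"role": "assistant", "content": "response"}, "done": True},
    ]
    # Awaiting the patched coroutine returns the async generator, which the
    # caller can then consume with `async for`, just as the integration does
    # with the real client when stream=True.
    with patch.object(FakeClient, "chat", return_value=stream_generator(chunks)):
        result = await FakeClient().chat(model="test model")
        text = "".join([c["message"]["content"] async for c in result])
    print(text)  # -> "test response"


asyncio.run(main())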