Update Ollama to use streaming API (#138177)
* Update ollama to use streaming APIs
* Remove unnecessary logging
* Update homeassistant/components/ollama/conversation.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
commit 15223b3679 (parent ae38f89728)
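The change replaces the single, fully buffered chat completion with Ollama's streaming chat API in homeassistant/components/ollama/conversation.py. As a minimal sketch of the difference (not part of the commit, and assuming the ollama Python client; the host and model name are placeholders):

# Illustrative sketch only, not code from this commit. Assumes the `ollama`
# Python client; the host and model name below are placeholders.
import asyncio

from ollama import AsyncClient


async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")  # placeholder host
    messages = [{"role": "user", "content": "What is the capital of France?"}]

    # Before: stream=False returns one completed response.
    response = await client.chat(model="llama3.2", messages=messages, stream=False)
    print(response["message"]["content"])

    # After: stream=True returns an async iterator of partial chunks, which the
    # integration now consumes with `async for` and converts into chat-log deltas.
    stream = await client.chat(model="llama3.2", messages=messages, stream=True)
    async for chunk in stream:
        print(chunk["message"]["content"], end="", flush=True)


asyncio.run(main())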
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 import json
 import logging
 from typing import Any, Literal
@@ -123,7 +123,47 @@ def _convert_content(
             role=MessageRole.SYSTEM.value,
             content=chat_content.content,
         )
-    raise ValueError(f"Unexpected content type: {type(chat_content)}")
+    raise TypeError(f"Unexpected content type: {type(chat_content)}")
+
+
+async def _transform_stream(
+    result: AsyncGenerator[ollama.Message],
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    """Transform the response stream into HA format.
+
+    An Ollama streaming response may come in chunks like this:
+
+    response: message=Message(role="assistant", content="Paris")
+    response: message=Message(role="assistant", content=".")
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+    response: message=Message(role="assistant", tool_calls=[...])
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+
+    This generator conforms to the chatlog delta stream expectations in that it
+    yields deltas, then the role only once the response is done.
+    """
+
+    new_msg = True
+    async for response in result:
+        _LOGGER.debug("Received response: %s", response)
+        response_message = response["message"]
+        chunk: conversation.AssistantContentDeltaDict = {}
+        if new_msg:
+            new_msg = False
+            chunk["role"] = "assistant"
+        if (tool_calls := response_message.get("tool_calls")) is not None:
+            chunk["tool_calls"] = [
+                llm.ToolInput(
+                    tool_name=tool_call["function"]["name"],
+                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
+                )
+                for tool_call in tool_calls
+            ]
+        if (content := response_message.get("content")) is not None:
+            chunk["content"] = content
+        if response_message.get("done"):
+            new_msg = True
+        yield chunk
 
 
 class OllamaConversationEntity(
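To make the delta contract described in the _transform_stream docstring concrete, here is a simplified, self-contained stand-in (not the Home Assistant implementation, which additionally emits tool-call deltas via llm.ToolInput): the role is yielded only with the first chunk of each assistant message, content arrives as incremental deltas, and a done chunk resets the state so the next chunk starts a new message.

# Simplified stand-in for illustration only; the real _transform_stream also
# handles tool calls and Home Assistant's typed delta dicts.
import asyncio
from collections.abc import AsyncGenerator
from typing import Any


async def fake_ollama_stream() -> AsyncGenerator[dict[str, Any], None]:
    for chunk in (
        {"message": {"role": "assistant", "content": "Paris"}},
        {"message": {"role": "assistant", "content": "."}},
        {"message": {"role": "assistant", "content": "", "done": True}},
    ):
        yield chunk


async def to_deltas(
    stream: AsyncGenerator[dict[str, Any], None],
) -> AsyncGenerator[dict[str, Any], None]:
    new_msg = True
    async for response in stream:
        message = response["message"]
        delta: dict[str, Any] = {}
        if new_msg:
            new_msg = False
            delta["role"] = "assistant"
        if (content := message.get("content")) is not None:
            delta["content"] = content
        if message.get("done"):
            new_msg = True
        yield delta


async def main() -> None:
    async for delta in to_deltas(fake_ollama_stream()):
        print(delta)
        # Prints {'role': 'assistant', 'content': 'Paris'}, then {'content': '.'},
        # then {'content': ''} for the final done chunk.


asyncio.run(main())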
@@ -216,12 +256,12 @@ class OllamaConversationEntity(
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
             try:
-                response = await client.chat(
+                response_generator = await client.chat(
                     model=model,
                     # Make a copy of the messages because we mutate the list later
                     messages=list(message_history.messages),
                     tools=tools,
-                    stream=False,
+                    stream=True,
                     # keep_alive requires specifying unit. In this case, seconds
                     keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
                     options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
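A note on the keep_alive context line above: the option is sent to Ollama as a string with an explicit unit, so the configured number of seconds is formatted as, for example, "30s". A tiny illustration with placeholder names (the real constants and default live in the integration's const module):

# Placeholder values for illustration only; not the integration's real constants.
DEFAULT_KEEP_ALIVE = 30  # seconds, placeholder default
settings: dict[str, int] = {}

keep_alive = f"{settings.get('keep_alive', DEFAULT_KEEP_ALIVE)}s"
assert keep_alive == "30s"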
@@ -232,46 +272,25 @@ class OllamaConversationEntity(
                     f"Sorry, I had a problem talking to the Ollama server: {err}"
                 ) from err
 
-            response_message = response["message"]
-            content = response_message.get("content")
-            tool_calls = response_message.get("tool_calls")
-            message_history.messages.append(
-                ollama.Message(
-                    role=response_message["role"],
-                    content=content,
-                    tool_calls=tool_calls,
-                )
-            )
-            tool_inputs = [
-                llm.ToolInput(
-                    tool_name=tool_call["function"]["name"],
-                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
-                )
-                for tool_call in tool_calls or ()
-            ]
-
             message_history.messages.extend(
                 [
-                    ollama.Message(
-                        role=MessageRole.TOOL.value,
-                        content=json.dumps(tool_response.tool_result),
-                    )
-                    async for tool_response in chat_log.async_add_assistant_content(
-                        conversation.AssistantContent(
-                            agent_id=user_input.agent_id,
-                            content=content,
-                            tool_calls=tool_inputs or None,
-                        )
+                    _convert_content(content)
+                    async for content in chat_log.async_add_delta_content_stream(
+                        user_input.agent_id, _transform_stream(response_generator)
                     )
                 ]
             )
 
-            if not tool_calls:
+            if not chat_log.unresponded_tool_results:
                 break
 
         # Create intent response
         intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_speech(response_message["content"])
+        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
+            raise TypeError(
+                f"Unexpected last message type: {type(chat_log.content[-1])}"
+            )
+        intent_response.async_set_speech(chat_log.content[-1].content or "")
         return conversation.ConversationResult(
             response=intent_response, conversation_id=chat_log.conversation_id
         )
@@ -1,5 +1,6 @@
 """Tests for the Ollama integration."""
 
+from collections.abc import AsyncGenerator
 from typing import Any
 from unittest.mock import AsyncMock, Mock, patch
 
@@ -25,6 +26,14 @@ def mock_ulid_tools():
         yield
 
 
+async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]:
+    """Generate a response from the assistant."""
+    if not isinstance(response, list):
+        response = [response]
+    for msg in response:
+        yield msg
+
+
 @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
 async def test_chat(
     hass: HomeAssistant,
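The stream_generator helper added above is how the tests fake a streamed response: it wraps one or more chunk dicts in an async generator, matching what the patched ollama.AsyncClient.chat now returns. A standalone usage sketch (the helper copied from the hunk above plus a hypothetical demo consumer, not a test from this commit):

# Standalone sketch for illustration; `demo` is a hypothetical consumer,
# not a test from this commit.
import asyncio
from collections.abc import AsyncGenerator


async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict, None]:
    """Generate a response from the assistant."""
    if not isinstance(response, list):
        response = [response]
    for msg in response:
        yield msg


async def demo() -> None:
    parts: list[str] = []
    async for msg in stream_generator(
        [
            {"message": {"role": "assistant", "content": "test "}},
            {"message": {"role": "assistant", "content": "response"}, "done": True},
        ]
    ):
        parts.append(msg["message"]["content"])
    assert "".join(parts) == "test response"


asyncio.run(demo())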
@@ -42,7 +51,9 @@ async def test_chat(
 
     with patch(
         "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
+        return_value=stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        ),
     ) as mock_chat:
        result = await conversation.async_converse(
             hass,
@@ -81,6 +92,53 @@ async def test_chat(
     assert "Current time is" in detail_event["data"]["messages"][0]["content"]
 
 
+async def test_chat_stream(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test chat messages are assembled across streamed responses."""
+
+    entry = MockConfigEntry()
+    entry.add_to_hass(hass)
+
+    with patch(
+        "ollama.AsyncClient.chat",
+        return_value=stream_generator(
+            [
+                {"message": {"role": "assistant", "content": "test "}},
+                {
+                    "message": {"role": "assistant", "content": "response"},
+                    "done": True,
+                    "done_reason": "stop",
+                },
+            ],
+        ),
+    ) as mock_chat:
+        result = await conversation.async_converse(
+            hass,
+            "test message",
+            None,
+            Context(),
+            agent_id=mock_config_entry.entry_id,
+        )
+
+    assert mock_chat.call_count == 1
+    args = mock_chat.call_args.kwargs
+    prompt = args["messages"][0]["content"]
+
+    assert args["model"] == "test model"
+    assert args["messages"] == [
+        Message(role="system", content=prompt),
+        Message(role="user", content="test message"),
+    ]
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
+        result
+    )
+    assert result.response.speech["plain"]["speech"] == "test response"
+
+
 async def test_template_variables(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry
 ) -> None:
@@ -103,7 +161,9 @@ async def test_template_variables(
         patch("ollama.AsyncClient.list"),
         patch(
             "ollama.AsyncClient.chat",
-            return_value={"message": {"role": "assistant", "content": "test response"}},
+            return_value=stream_generator(
+                {"message": {"role": "assistant", "content": "test response"}}
+            ),
         ) as mock_chat,
         patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
     ):
@@ -170,26 +230,30 @@ async def test_function_call(
     def completion_result(*args, messages, **kwargs):
         for message in messages:
             if message["role"] == "tool":
-                return {
-                    "message": {
-                        "role": "assistant",
-                        "content": "I have successfully called the function",
-                    }
-                }
-
-        return {
-            "message": {
-                "role": "assistant",
-                "tool_calls": [
-                    {
-                        "function": {
-                            "name": "test_tool",
-                            "arguments": tool_args,
-                        }
-                    }
-                ],
-            }
-        }
+                return stream_generator(
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": "I have successfully called the function",
+                        }
+                    }
+                )
+
+        return stream_generator(
+            {
+                "message": {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "test_tool",
+                                "arguments": tool_args,
+                            }
+                        }
+                    ],
+                }
+            }
+        )
 
     with patch(
         "ollama.AsyncClient.chat",
@@ -251,26 +315,30 @@ async def test_function_exception(
     def completion_result(*args, messages, **kwargs):
         for message in messages:
             if message["role"] == "tool":
-                return {
-                    "message": {
-                        "role": "assistant",
-                        "content": "There was an error calling the function",
-                    }
-                }
-
-        return {
-            "message": {
-                "role": "assistant",
-                "tool_calls": [
-                    {
-                        "function": {
-                            "name": "test_tool",
-                            "arguments": {"param1": "test_value"},
-                        }
-                    }
-                ],
-            }
-        }
+                return stream_generator(
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": "There was an error calling the function",
+                        }
+                    }
+                )
+
+        return stream_generator(
+            {
+                "message": {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "test_tool",
+                                "arguments": {"param1": "test_value"},
+                            }
+                        }
+                    ],
+                }
+            }
+        )
 
     with patch(
         "ollama.AsyncClient.chat",
@@ -344,7 +412,9 @@ async def test_message_history_trimming(
     def response(*args, **kwargs) -> dict:
         nonlocal response_idx
         response_idx += 1
-        return {"message": {"role": "assistant", "content": f"response {response_idx}"}}
+        return stream_generator(
+            {"message": {"role": "assistant", "content": f"response {response_idx}"}}
+        )
 
     with patch(
         "ollama.AsyncClient.chat",
@@ -438,11 +508,13 @@ async def test_message_history_unlimited(
     """Test that message history is not trimmed when max_history = 0."""
     conversation_id = "1234"
 
+    def stream(*args, **kwargs) -> AsyncGenerator[dict]:
+        return stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        )
+
     with (
-        patch(
-            "ollama.AsyncClient.chat",
-            return_value={"message": {"role": "assistant", "content": "test response"}},
-        ) as mock_chat,
+        patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat,
     ):
         hass.config_entries.async_update_entry(
             mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
@@ -559,7 +631,9 @@ async def test_options(
     """Test that options are passed correctly to ollama client."""
     with patch(
         "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
+        return_value=stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        ),
     ) as mock_chat:
         await conversation.async_converse(
             hass,