Mirror of https://github.com/home-assistant/core.git (synced 2025-07-25 22:27:07 +00:00)
Update Ollama to use streaming API (#138177)
* Update ollama to use streaming APIs
* Remove unnecessary logging
* Update homeassistant/components/ollama/conversation.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Parent: ae38f89728
Commit: 15223b3679
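The heart of the change is switching the Ollama client call from a single blocking response to a streamed one. A rough standalone sketch (not integration code) of the two call styles, assuming the ollama Python client and a locally running Ollama server; the model name is a placeholder:

# Rough sketch only: assumes the ollama Python package and a local Ollama
# server; "llama3.2" is a placeholder model name.
import asyncio

import ollama


async def main() -> None:
    client = ollama.AsyncClient()
    messages = [{"role": "user", "content": "What is the capital of France?"}]

    # Before this change: stream=False returns one complete response.
    response = await client.chat(model="llama3.2", messages=messages, stream=False)
    print(response["message"]["content"])

    # After this change: stream=True returns an async iterator of partial
    # chunks, which the integration transforms into chat-log deltas.
    stream = await client.chat(model="llama3.2", messages=messages, stream=True)
    async for chunk in stream:
        print(chunk["message"]["content"], end="", flush=True)
    print()


asyncio.run(main())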
homeassistant/components/ollama/conversation.py:

@@ -2,7 +2,7 @@
 from __future__ import annotations
 
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 import json
 import logging
 from typing import Any, Literal
@@ -123,7 +123,47 @@ def _convert_content(
             role=MessageRole.SYSTEM.value,
             content=chat_content.content,
         )
-    raise ValueError(f"Unexpected content type: {type(chat_content)}")
+    raise TypeError(f"Unexpected content type: {type(chat_content)}")
 
 
+async def _transform_stream(
+    result: AsyncGenerator[ollama.Message],
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    """Transform the response stream into HA format.
+
+    An Ollama streaming response may come in chunks like this:
+
+    response: message=Message(role="assistant", content="Paris")
+    response: message=Message(role="assistant", content=".")
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+    response: message=Message(role="assistant", tool_calls=[...])
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+
+    This generator conforms to the chatlog delta stream expectations in that it
+    yields deltas, then the role only once the response is done.
+    """
+
+    new_msg = True
+    async for response in result:
+        _LOGGER.debug("Received response: %s", response)
+        response_message = response["message"]
+        chunk: conversation.AssistantContentDeltaDict = {}
+        if new_msg:
+            new_msg = False
+            chunk["role"] = "assistant"
+        if (tool_calls := response_message.get("tool_calls")) is not None:
+            chunk["tool_calls"] = [
+                llm.ToolInput(
+                    tool_name=tool_call["function"]["name"],
+                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
+                )
+                for tool_call in tool_calls
+            ]
+        if (content := response_message.get("content")) is not None:
+            chunk["content"] = content
+        if response_message.get("done"):
+            new_msg = True
+        yield chunk
+
+
 class OllamaConversationEntity(
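To make the delta contract described in the docstring concrete, here is a small standalone illustration of how such delta dicts can be folded back into complete messages. The helper names (fake_stream, assemble) are made up and are not part of Home Assistant or ollama:

# Illustration only: shows the "role once per message, then content/tool_call
# deltas" contract that _transform_stream yields.
import asyncio
from collections.abc import AsyncGenerator
from typing import Any


async def fake_stream() -> AsyncGenerator[dict[str, Any]]:
    """Yield delta dicts shaped like the ones described above."""
    for chunk in (
        {"role": "assistant", "content": "Paris"},
        {"content": "."},
        {"role": "assistant", "tool_calls": [{"tool_name": "demo"}]},
    ):
        yield chunk


async def assemble(deltas: AsyncGenerator[dict[str, Any]]) -> list[dict[str, Any]]:
    """Fold delta dicts back into complete messages."""
    messages: list[dict[str, Any]] = []
    async for delta in deltas:
        if "role" in delta:
            # A new role delta starts a new message.
            messages.append({"role": delta["role"], "content": "", "tool_calls": []})
        if "content" in delta:
            messages[-1]["content"] += delta["content"]
        if "tool_calls" in delta:
            messages[-1]["tool_calls"].extend(delta["tool_calls"])
    return messages


print(asyncio.run(assemble(fake_stream())))
# [{'role': 'assistant', 'content': 'Paris.', 'tool_calls': []},
#  {'role': 'assistant', 'content': '', 'tool_calls': [{'tool_name': 'demo'}]}]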
@@ -216,12 +256,12 @@ class OllamaConversationEntity(
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
             try:
-                response = await client.chat(
+                response_generator = await client.chat(
                     model=model,
                     # Make a copy of the messages because we mutate the list later
                     messages=list(message_history.messages),
                     tools=tools,
-                    stream=False,
+                    stream=True,
                     # keep_alive requires specifying unit. In this case, seconds
                     keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
                     options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
@@ -232,46 +272,25 @@ class OllamaConversationEntity(
                     f"Sorry, I had a problem talking to the Ollama server: {err}"
                 ) from err
 
-            response_message = response["message"]
-            content = response_message.get("content")
-            tool_calls = response_message.get("tool_calls")
-            message_history.messages.append(
-                ollama.Message(
-                    role=response_message["role"],
-                    content=content,
-                    tool_calls=tool_calls,
-                )
-            )
-            tool_inputs = [
-                llm.ToolInput(
-                    tool_name=tool_call["function"]["name"],
-                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
-                )
-                for tool_call in tool_calls or ()
-            ]
-
             message_history.messages.extend(
                 [
-                    ollama.Message(
-                        role=MessageRole.TOOL.value,
-                        content=json.dumps(tool_response.tool_result),
-                    )
-                    async for tool_response in chat_log.async_add_assistant_content(
-                        conversation.AssistantContent(
-                            agent_id=user_input.agent_id,
-                            content=content,
-                            tool_calls=tool_inputs or None,
-                        )
+                    _convert_content(content)
+                    async for content in chat_log.async_add_delta_content_stream(
+                        user_input.agent_id, _transform_stream(response_generator)
                     )
                 ]
             )
 
-            if not tool_calls:
+            if not chat_log.unresponded_tool_results:
                 break
 
         # Create intent response
         intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_speech(response_message["content"])
+        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
+            raise TypeError(
+                f"Unexpected last message type: {type(chat_log.content[-1])}"
+            )
+        intent_response.async_set_speech(chat_log.content[-1].content or "")
         return conversation.ConversationResult(
             response=intent_response, conversation_id=chat_log.conversation_id
         )
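The reworked loop above leaves it to the chat log to decide when to stop: keep requesting model turns until no tool results remain unanswered, bounded by MAX_TOOL_ITERATIONS. A toy, self-contained sketch of that control flow; every name below is an invented stand-in, not a Home Assistant API:

# Toy sketch, not Home Assistant code: fake_model and agent_loop only mirror
# the loop structure shown in the hunk above.
MAX_TOOL_ITERATIONS = 10  # same bound used by the integration


def fake_model(turn: int) -> list[dict]:
    """Pretend model: the first turn requests a tool, the second answers in text."""
    if turn == 0:
        return [{"role": "assistant", "tool_calls": [{"name": "demo_tool"}]}]
    return [{"role": "assistant", "content": "All done."}]


def agent_loop() -> str:
    reply = ""
    for turn in range(MAX_TOOL_ITERATIONS):
        deltas = fake_model(turn)
        # A real chat log executes the requested tools here and queues their
        # results; "unresponded" means a tool result is still waiting to be
        # fed back to the model on the next turn.
        unresponded_tool_results = any("tool_calls" in d for d in deltas)
        reply = "".join(d.get("content", "") for d in deltas)
        if not unresponded_tool_results:
            break
    return reply


print(agent_loop())  # -> All done.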
Tests for the Ollama integration:

@@ -1,5 +1,6 @@
 """Tests for the Ollama integration."""
 
+from collections.abc import AsyncGenerator
 from typing import Any
 from unittest.mock import AsyncMock, Mock, patch
 
@@ -25,6 +26,14 @@ def mock_ulid_tools():
     yield
 
 
+async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]:
+    """Generate a response from the assistant."""
+    if not isinstance(response, list):
+        response = [response]
+    for msg in response:
+        yield msg
+
+
 @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
 async def test_chat(
     hass: HomeAssistant,
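The new stream_generator helper replays one or more canned response dicts as an async stream, so a single dict becomes a one-chunk stream. A quick standalone check of that behaviour; the helper is copied from the hunk above so the snippet runs on its own:

# Standalone check of the test helper's behaviour.
import asyncio
from collections.abc import AsyncGenerator


async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]:
    """Generate a response from the assistant."""
    if not isinstance(response, list):
        response = [response]
    for msg in response:
        yield msg


async def collect() -> list[dict]:
    return [
        chunk
        async for chunk in stream_generator(
            {"message": {"role": "assistant", "content": "test response"}}
        )
    ]


print(asyncio.run(collect()))
# [{'message': {'role': 'assistant', 'content': 'test response'}}]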
@@ -42,7 +51,9 @@ async def test_chat(
 
     with patch(
         "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
+        return_value=stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        ),
     ) as mock_chat:
         result = await conversation.async_converse(
             hass,
@@ -81,6 +92,53 @@ async def test_chat(
     assert "Current time is" in detail_event["data"]["messages"][0]["content"]
 
 
+async def test_chat_stream(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test chat messages are assembled across streamed responses."""
+
+    entry = MockConfigEntry()
+    entry.add_to_hass(hass)
+
+    with patch(
+        "ollama.AsyncClient.chat",
+        return_value=stream_generator(
+            [
+                {"message": {"role": "assistant", "content": "test "}},
+                {
+                    "message": {"role": "assistant", "content": "response"},
+                    "done": True,
+                    "done_reason": "stop",
+                },
+            ],
+        ),
+    ) as mock_chat:
+        result = await conversation.async_converse(
+            hass,
+            "test message",
+            None,
+            Context(),
+            agent_id=mock_config_entry.entry_id,
+        )
+
+    assert mock_chat.call_count == 1
+    args = mock_chat.call_args.kwargs
+    prompt = args["messages"][0]["content"]
+
+    assert args["model"] == "test model"
+    assert args["messages"] == [
+        Message(role="system", content=prompt),
+        Message(role="user", content="test message"),
+    ]
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
+        result
+    )
+    assert result.response.speech["plain"]["speech"] == "test response"
+
+
 async def test_template_variables(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry
 ) -> None:
@@ -103,7 +161,9 @@ async def test_template_variables(
         patch("ollama.AsyncClient.list"),
         patch(
            "ollama.AsyncClient.chat",
-            return_value={"message": {"role": "assistant", "content": "test response"}},
+            return_value=stream_generator(
+                {"message": {"role": "assistant", "content": "test response"}}
+            ),
        ) as mock_chat,
        patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
     ):
@@ -170,26 +230,30 @@ async def test_function_call(
     def completion_result(*args, messages, **kwargs):
         for message in messages:
             if message["role"] == "tool":
-                return {
-                    "message": {
-                        "role": "assistant",
-                        "content": "I have successfully called the function",
-                    }
-                }
+                return stream_generator(
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": "I have successfully called the function",
+                        }
+                    }
+                )
 
-        return {
-            "message": {
-                "role": "assistant",
-                "tool_calls": [
-                    {
-                        "function": {
-                            "name": "test_tool",
-                            "arguments": tool_args,
-                        }
-                    }
-                ],
-            }
-        }
+        return stream_generator(
+            {
+                "message": {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "test_tool",
+                                "arguments": tool_args,
+                            }
+                        }
+                    ],
+                }
+            }
+        )
 
     with patch(
         "ollama.AsyncClient.chat",
@@ -251,26 +315,30 @@ async def test_function_exception(
     def completion_result(*args, messages, **kwargs):
         for message in messages:
             if message["role"] == "tool":
-                return {
-                    "message": {
-                        "role": "assistant",
-                        "content": "There was an error calling the function",
-                    }
-                }
+                return stream_generator(
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": "There was an error calling the function",
+                        }
+                    }
+                )
 
-        return {
-            "message": {
-                "role": "assistant",
-                "tool_calls": [
-                    {
-                        "function": {
-                            "name": "test_tool",
-                            "arguments": {"param1": "test_value"},
-                        }
-                    }
-                ],
-            }
-        }
+        return stream_generator(
+            {
+                "message": {
+                    "role": "assistant",
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": "test_tool",
+                                "arguments": {"param1": "test_value"},
+                            }
+                        }
+                    ],
+                }
+            }
+        )
 
     with patch(
         "ollama.AsyncClient.chat",
@@ -344,7 +412,9 @@ async def test_message_history_trimming(
     def response(*args, **kwargs) -> dict:
         nonlocal response_idx
         response_idx += 1
-        return {"message": {"role": "assistant", "content": f"response {response_idx}"}}
+        return stream_generator(
+            {"message": {"role": "assistant", "content": f"response {response_idx}"}}
+        )
 
     with patch(
         "ollama.AsyncClient.chat",
@@ -438,11 +508,13 @@ async def test_message_history_unlimited(
     """Test that message history is not trimmed when max_history = 0."""
     conversation_id = "1234"
 
+    def stream(*args, **kwargs) -> AsyncGenerator[dict]:
+        return stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        )
+
     with (
-        patch(
-            "ollama.AsyncClient.chat",
-            return_value={"message": {"role": "assistant", "content": "test response"}},
-        ) as mock_chat,
+        patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat,
     ):
         hass.config_entries.async_update_entry(
             mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
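The switch from return_value to side_effect=stream in the hunk above matters because an async generator can only be consumed once; a patched chat that is called repeatedly needs a fresh generator per call. A minimal standalone illustration of why, where one_chunk is a made-up stand-in for stream_generator:

# Minimal illustration of generator exhaustion; not test code from the repo.
import asyncio
from collections.abc import AsyncGenerator


async def one_chunk() -> AsyncGenerator[dict]:
    yield {"message": {"role": "assistant", "content": "test response"}}


async def main() -> None:
    stream = one_chunk()
    first = [c async for c in stream]   # one chunk
    second = [c async for c in stream]  # [] -- the generator is already spent
    assert first and not second

    # side_effect=stream passes a callable, so every patched chat() call
    # builds a fresh generator instead of reusing the exhausted object that a
    # fixed return_value would hand out.
    fresh = [c async for c in one_chunk()]
    assert fresh == first


asyncio.run(main())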
@@ -559,7 +631,9 @@ async def test_options(
     """Test that options are passed correctly to ollama client."""
     with patch(
         "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
+        return_value=stream_generator(
+            {"message": {"role": "assistant", "content": "test response"}}
+        ),
     ) as mock_chat:
         await conversation.async_converse(
             hass,