Compare commits

..

27 Commits

Author SHA1 Message Date
Jeffrey Morgan
5534f2cc6a llm: consider head_dim in llama arch (#5817) 2024-07-20 21:48:12 -04:00
Daniel Hiltgen
d321297d8a Merge pull request #5815 from dhiltgen/win_rocm_gfx_features
Adjust windows ROCm discovery
2024-07-20 16:02:55 -07:00
Daniel Hiltgen
06e5d74e34 Merge pull request #5506 from dhiltgen/sched_tests
Refine scheduler unit tests for reliability
2024-07-20 15:48:39 -07:00
Daniel Hiltgen
5d707e6fd5 Merge pull request #5583 from dhiltgen/integration_improvements
Fix context exhaustion integration test for small gpus
2024-07-20 15:48:21 -07:00
Daniel Hiltgen
283948c83b Adjust windows ROCm discovery
The v5 hip library returns unsupported GPUs which wont enumerate at
inference time in the runner so this makes sure we align discovery.  The
gfx906 cards are no longer supported so we shouldn't compile with that
GPU type as it wont enumerate at runtime.
2024-07-20 15:17:50 -07:00
Jeffrey Morgan
1475eab95f add patch for tekken (#5807) 2024-07-20 13:41:21 -04:00
Jeffrey Morgan
20090f3172 preserve last assistant message (#5802) 2024-07-19 20:19:26 -07:00
Jeffrey Morgan
69a2d4ccff Fix generate test flakyness (#5804) 2024-07-19 19:11:25 -07:00
Josh
e8b954c646 server: validate template (#5734)
add template validation to modelfile
2024-07-19 15:24:29 -07:00
royjhan
c57317cbf0 OpenAI: Function Based Testing (#5752)
* distinguish error forwarding

* more coverage

* rm comment
2024-07-19 11:37:12 -07:00
royjhan
51b2fd299c adjust openai chat msg processing (#5729) 2024-07-19 11:19:20 -07:00
Michael Yang
d0634b1596 Merge pull request #5780 from ollama/mxyng/tools
fix parsing tool calls: break on unexpected eofs
2024-07-18 12:14:10 -07:00
Michael Yang
43606d6d6a fix parsing tool calls 2024-07-18 12:08:11 -07:00
Jeffrey Morgan
70b1010fa5 server: check for empty tools array too (#5779) 2024-07-18 11:44:57 -07:00
Jeffrey Morgan
84e5721f3a always provide content even if empty (#5778) 2024-07-18 11:28:19 -07:00
Jeffrey Morgan
319fb1ce03 server: only parse tool calls if tools are provided (#5771)
* server: only parse tool calls if tools are provided

* still set `resp.Message.Content`
2024-07-18 08:50:23 -07:00
Michael Yang
b255445557 marshal json automatically for some template values (#5758) 2024-07-17 15:35:11 -07:00
Michael Yang
b23424bb3c Merge pull request #5753 from ollama/mxyng/parse-tool-call
parse tool call as individual objects
2024-07-17 11:47:53 -07:00
Michael Yang
5fd6988126 parse tool call as individual objects 2024-07-17 11:19:04 -07:00
Michael Yang
5b82960df8 stub response (#5750) 2024-07-17 10:39:22 -07:00
Michael Yang
cc9a252d8c Merge pull request #5732 from ollama/mxyng/cleanup
remove ToolCall from GenerateResponse
2024-07-17 10:26:54 -07:00
Pákozdi György
d281a6e603 add sidellama link (#5702) 2024-07-17 10:24:44 -07:00
royjhan
154f6f45d4 OpenAI: Support Tools (#5614)
* reopen pr

* tools

* remove tc from stream for now

* ID and Function

* openai expects arguments to be a string (#5739)

* mutually exclusive content and tool calls

* clean up

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-07-16 20:52:59 -07:00
royjhan
0d41623b52 OpenAI: Add Suffix to v1/completions (#5611)
* add suffix

* remove todo

* remove TODO

* add to test

* rm outdated prompt tokens info md

* fix test

* fix test
2024-07-16 20:50:14 -07:00
Michael Yang
c279f96371 remove ToolCall from GenerateResponse 2024-07-16 15:22:49 -07:00
Daniel Hiltgen
73e2c8f68f Fix context exhaustion integration test for small gpus
On the smaller GPUs, the initial model load of llama2 took over 30s (the
default timeout for the DoGenerate helper)
2024-07-09 16:24:14 -07:00
Daniel Hiltgen
f4408219e9 Refine scheduler unit tests for reliability
This breaks up some of the test scenarios to create a
more reliable set of tests, as well as adding a little more
coverage.
2024-07-09 16:00:08 -07:00
27 changed files with 863 additions and 481 deletions

View File

@@ -295,6 +295,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
### Terminal

View File

@@ -101,46 +101,29 @@ type ChatRequest struct {
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to.
Tools []Tool `json:"tools,omitempty"`
Tools `json:"tools,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
type Tools []Tool
func (t Tools) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// Message is a single message in a chat sequence. The message contains the
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
type ToolCall struct {
Function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
} `json:"function"`
}
type Tool struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
} `json:"function"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message
var a Alias
@@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
return nil
}
type ToolCall struct {
Function ToolCallFunction `json:"function"`
}
type ToolCallFunction struct {
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}
type ToolCallFunctionArguments map[string]any
func (t *ToolCallFunctionArguments) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
type Tool struct {
Type string `json:"type"`
Function ToolFunction `json:"function"`
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}
func (t *ToolFunction) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
@@ -405,9 +428,6 @@ type GenerateResponse struct {
// Response is the textual response itself.
Response string `json:"response"`
// ToolCalls is the list of tools the model wants to call
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// Done specifies if the response is complete.
Done bool `json:"done"`

View File

@@ -46,13 +46,24 @@ sudo modprobe nvidia_uvm`
## AMD Radeon
Ollama supports the following AMD GPUs:
### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
### Overrides
### Windows Support
With ROCm v6.1, the following GPUs are supported on Windows.
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
### Overrides on Linux
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
some cases you can force the system to try to use a similar LLVM target that is
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
@@ -63,7 +74,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
server. If you have an unsupported AMD GPU you can experiment using the list of
supported types below.
At this time, the known supported GPU types are the following LLVM Targets.
At this time, the known supported GPU types on linux are the following LLVM Targets.
This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** |
|-----------------|---------------------|

View File

@@ -27,11 +27,6 @@ chat_completion = client.chat.completions.create(
],
model='llama3',
)
completion = client.completions.create(
model="llama3",
prompt="Say this is a test"
)
```
### OpenAI JavaScript library
@@ -50,11 +45,6 @@ const chatCompletion = await openai.chat.completions.create({
messages: [{ role: 'user', content: 'Say this is a test' }],
model: 'llama3',
})
const completion = await openai.completions.create({
model: "llama3",
prompt: "Say this is a test.",
})
```
### `curl`
@@ -76,12 +66,6 @@ curl http://localhost:11434/v1/chat/completions \
]
}'
curl http://localhost:11434/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama3",
"prompt": "Say this is a test"
}'
```
## Endpoints
@@ -119,73 +103,6 @@ curl http://localhost:11434/v1/completions \
- [ ] `user`
- [ ] `n`
### `/v1/completions`
#### Supported features
- [x] Completions
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [ ] Logprobs
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `frequency_penalty`
- [x] `presence_penalty`
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [x] `suffix`
- [ ] `best_of`
- [ ] `echo`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`
#### Notes
- `prompt` currently only accepts a string
### `/v1/completions`
#### Supported features
- [x] Completions
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [ ] Logprobs
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `frequency_penalty`
- [x] `presence_penalty`
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [ ] `best_of`
- [ ] `echo`
- [ ] `suffix`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`
#### Notes
- `prompt` currently only accepts a string
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
## Models
Before using a model, pull it locally `ollama pull`:

View File

@@ -33,9 +33,10 @@ type HipLib struct {
}
func NewHipLib() (*HipLib, error) {
h, err := windows.LoadLibrary("amdhip64.dll")
// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
h, err := windows.LoadLibrary("amdhip64_6.dll")
if err != nil {
return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
}
hl := &HipLib{}
hl.dll = h

View File

@@ -92,7 +92,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
continue
}
if gfxOverride == "" {
if !slices.Contains[[]string, string](supported, gfx) {
// Strip off Target Features when comparing
if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")

View File

@@ -12,7 +12,7 @@ import (
func TestContextExhaustion(t *testing.T) {
// Longer needed for small footprint GPUs
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
@@ -25,5 +25,10 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128,
},
}
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
}

View File

@@ -7,8 +7,8 @@ function amdGPUs {
return $env:AMDGPU_TARGETS
}
# Current supported rocblas list from ROCm v6.1.2 on windows
# https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus
$GPU_LIST = @(
"gfx906:xnack-"
"gfx1030"
"gfx1100"
"gfx1101"

View File

@@ -0,0 +1,43 @@
diff --git a/include/llama.h b/include/llama.h
index bb4b05ba..a92174e0 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -92,6 +92,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
+ LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
};
// note: these values should be synchronized with ggml_rope
diff --git a/src/llama.cpp b/src/llama.cpp
index 18364976..435b6fe5 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -5429,6 +5429,12 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "jais") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
+ } else if (
+ tokenizer_pre == "tekken") {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
+ vocab.tokenizer_clean_spaces = false;
+ vocab.tokenizer_ignore_merges = true;
+ vocab.tokenizer_add_bos = true;
} else {
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
@@ -15448,6 +15454,13 @@ struct llm_tokenizer_bpe {
" ?[^(\\s|.,!?…。,、।۔،)]+",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
+ // original regex from tokenizer.json
+ // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+ regex_exprs = {
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ };
+ break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {

View File

@@ -0,0 +1,19 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 2b9ace28..e60d3d8d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -6052,10 +6052,10 @@ static bool llm_load_tensors(
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
- layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
- layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
- layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
- layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
// optional bias tensors
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);

View File

@@ -385,8 +385,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
filteredEnv := []string{}
for _, ev := range s.cmd.Env {
if strings.HasPrefix(ev, "CUDA_") ||
strings.HasPrefix(ev, "ROCR_") ||
strings.HasPrefix(ev, "ROCM_") ||
strings.HasPrefix(ev, "HIP_") ||
strings.HasPrefix(ev, "GPU_") ||
strings.HasPrefix(ev, "HSA_") ||
strings.HasPrefix(ev, "GGML_") ||
strings.HasPrefix(ev, "PATH=") ||

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"math/rand"
"net/http"
"strings"
@@ -29,8 +30,9 @@ type ErrorResponse struct {
}
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Role string `json:"role"`
Content any `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
type Choice struct {
@@ -78,6 +80,7 @@ type ChatCompletionRequest struct {
PresencePenalty *float64 `json:"presence_penalty_penalty"`
TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"`
}
type ChatCompletion struct {
@@ -111,6 +114,7 @@ type CompletionRequest struct {
Stream bool `json:"stream"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
}
type Completion struct {
@@ -132,6 +136,15 @@ type CompletionChunk struct {
SystemFingerprint string `json:"system_fingerprint"`
}
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
}
type Model struct {
Id string `json:"id"`
Object string `json:"object"`
@@ -170,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}}
}
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return "call_" + strings.ToLower(string(b))
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
for i, tc := range r.Message.ToolCalls {
toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name
args, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("could not marshall function arguments to json", "error", err)
continue
}
toolCalls[i].Function.Arguments = string(args)
}
return ChatCompletion{
Id: id,
Object: "chat.completion",
@@ -179,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
SystemFingerprint: "fp_ollama",
Choices: []Choice{{
Index: 0,
Message: Message{Role: r.Message.Role, Content: r.Message.Content},
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
FinishReason: func(reason string) *string {
if len(reason) > 0 {
return &reason
@@ -188,7 +225,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}(r.DoneReason),
}},
Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
@@ -234,7 +270,6 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
}(r.DoneReason),
}},
Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
@@ -316,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
case string:
messages = append(messages, api.Message{Role: msg.Role, Content: content})
case []any:
message := api.Message{Role: msg.Role}
for _, c := range content {
data, ok := c.(map[string]any)
if !ok {
@@ -328,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
if !ok {
return nil, fmt.Errorf("invalid message format")
}
message.Content = text
messages = append(messages, api.Message{Role: msg.Role, Content: text})
case "image_url":
var url string
if urlMap, ok := data["image_url"].(map[string]any); ok {
@@ -360,14 +394,26 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
if err != nil {
return nil, fmt.Errorf("invalid message format")
}
message.Images = append(message.Images, img)
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
default:
return nil, fmt.Errorf("invalid message format")
}
}
messages = append(messages, message)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
if msg.ToolCalls == nil {
return nil, fmt.Errorf("invalid message content type: %T", content)
}
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
for i, tc := range msg.ToolCalls {
toolCalls[i].Function.Name = tc.Function.Name
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
if err != nil {
return nil, fmt.Errorf("invalid tool call arguments")
}
}
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
}
}
@@ -425,6 +471,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
Format: format,
Options: options,
Stream: &r.Stream,
Tools: r.Tools,
}, nil
}
@@ -475,6 +522,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
Prompt: r.Prompt,
Options: options,
Stream: &r.Stream,
Suffix: r.Suffix,
}, nil
}
@@ -829,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {

View File

@@ -20,108 +20,59 @@ const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image
func TestMiddlewareRequests(t *testing.T) {
func prepareRequest(req *http.Request, body any) {
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err := json.Unmarshal(bodyBytes, capturedRequest)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
}
c.Next()
}
}
func TestChatMiddleware(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Handler func() gin.HandlerFunc
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *http.Request
captureRequestMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
capturedRequest = c.Request
c.Next()
}
}
var capturedRequest *api.ChatRequest
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Name: "chat handler",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
Handler: CompletionsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var genReq api.GenerateRequest
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
}
if genReq.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
}
if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
},
},
{
Name: "chat handler with image content",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Name: "chat handler with image content",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
@@ -134,87 +85,254 @@ func TestMiddlewareRequests(t *testing.T) {
},
},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
if req.Messages[1].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
}
if !bytes.Equal(req.Messages[1].Images[0], img) {
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
}
},
},
{
Name: "embed handler single input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Name: "chat handler with tools",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "What's the weather like in Paris Today?"},
{Role: "assistant", ToolCalls: []ToolCall{{
ID: "id",
Type: "function",
Function: struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}{
Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
},
}}},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != 200 {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{
Name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: 2}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
Name: "completions handler",
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if req.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Prompt)
}
if req.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
}
stopTokens, ok := req.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
},
},
{
Name: "completions handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: nil,
Stop: []int{1, 2},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.EmbedRequest
testCases := []testCase{
{
Name: "embed handler single input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if req.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Input)
}
if embedReq.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", embedReq.Input)
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
},
},
{
Name: "embed handler batch input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Name: "embed handler batch input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
}
input, ok := embedReq.Input.([]any)
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
input, ok := req.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
@@ -228,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
t.Fatalf("expected 'World', got %s", input[1])
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
},
},
{
Name: "embed handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid input") {
t.Fatalf("error was not forwarded")
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(captureRequestMiddleware())
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, endpoint)
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
@@ -275,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
}
testCases := []testCase{
{
Name: "completions handler error forwarding",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
t.Fatalf("error was not forwarded")
}
},
},
{
Name: "list handler",
Method: http.MethodGet,
@@ -321,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err)
@@ -386,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
tc.Expected(t, resp)
})
}

View File

@@ -492,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
layers = append(layers, baseLayer.Layer)
}
case "license", "template", "system":
if c.Name == "template" {
if _, err := template.Parse(c.Args); err != nil {
return fmt.Errorf("%w: %s", errBadTemplate, err)
}
}
if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {

View File

@@ -311,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]map[string]any{
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
"Function": map[string]any{
"Name": "@@name@@",
"Arguments": "@@arguments@@",
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
@@ -324,7 +326,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return nil, false
}
var kv map[string]string
var kv map[string]any
// execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
@@ -334,17 +336,19 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v {
case "@@name@@":
switch v.(type) {
case string:
name = k
case "@@arguments@@":
case map[string]any:
arguments = k
}
}
var objs []map[string]any
for offset := 0; offset < len(s); {
if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
@@ -353,10 +357,11 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
slog.Error("parseToolCalls", "error", err)
return nil, false
} else {
// break when an object is decoded
break
offset += int(decoder.InputOffset())
objs = append(objs, obj)
}
}

View File

@@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
}
}
type function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
@@ -167,6 +162,10 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"llama3-groq-tool-use", `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
</tool_call>`, true},
}
var tools []api.Tool
@@ -181,18 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
calls := []api.ToolCall{
{
Function: function{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: function{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},

View File

@@ -56,6 +56,7 @@ func init() {
}
var errRequired = errors.New("is required")
var errBadTemplate = errors.New("template error")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
@@ -275,11 +276,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
r.Response = sb.String()
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
r.ToolCalls = toolCalls
r.Response = ""
}
c.JSON(http.StatusOK, r)
return
}
@@ -614,8 +610,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
quantization := cmp.Or(r.Quantize, r.Quantization)
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
if errors.Is(err, errBadTemplate) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
}
ch <- gin.H{"error": err.Error()}
}
}
}()
if r.Stream != nil && !*r.Stream {
@@ -1201,11 +1200,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
return
}
case gin.H:
status, ok := r["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
c.JSON(status, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
return
}
default:
@@ -1295,7 +1298,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
caps := []Capability{CapabilityCompletion}
if req.Tools != nil {
if len(req.Tools) > 0 {
caps = append(caps, CapabilityTools)
}
@@ -1390,9 +1393,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
resp.Message.Content = sb.String()
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
}
c.JSON(http.StatusOK, resp)

View File

@@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) {
if string(system) != "Say bye!" {
t.Errorf("expected \"Say bye!\", actual %s", system)
}
t.Run("incomplete template", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with unclosed if", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with undefined function", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
}
func TestCreateLicenses(t *testing.T) {

View File

@@ -73,8 +73,8 @@ func TestGenerateChat(t *testing.T) {
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add 10ms delay to simulate loading
time.Sleep(10 * time.Millisecond)
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
@@ -371,6 +371,8 @@ func TestGenerate(t *testing.T) {
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"os"
"runtime"
"testing"
"time"
@@ -94,7 +95,7 @@ func TestLoad(t *testing.T) {
require.Len(t, s.expiredCh, 1)
}
type bundle struct {
type reqBundle struct {
ctx context.Context //nolint:containedctx
ctxDone func()
srv *mockLlm
@@ -102,13 +103,13 @@ type bundle struct {
ggml *llm.GGML
}
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return scenario.srv, nil
}
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
b := &reqBundle{}
b.ctx, b.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
@@ -135,124 +136,154 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
b.ggml, err = llm.LoadModel(model.ModelPath, 0)
require.NoError(t, err)
scenario.req = &LlmRequest{
ctx: scenario.ctx,
if duration == nil {
duration = &api.Duration{Duration: 5 * time.Millisecond}
}
b.req = &LlmRequest{
ctx: b.ctx,
model: model,
opts: api.DefaultOptions(),
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
sessionDuration: duration,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
return scenario
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
return b
}
func TestRequests(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
func getGpuFn() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
func getCpuFn() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
return []gpu.GpuInfo{g}
}
func TestRequestsSameModelSameRequest(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.getCpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
b.req.model = a.req.model
b.ggml = a.ggml
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
case resp := <-a.req.successCh:
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1a.req.errCh)
case err := <-scenario1a.req.errCh:
require.Empty(t, a.req.errCh)
case err := <-a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
// Same runner as first request due to not needing a reload
s.newServerFn = scenario1b.newServer
slog.Info("scenario1b")
s.pendingReqCh <- scenario1b.req
s.newServerFn = b.newServer
slog.Info("b")
s.pendingReqCh <- b.req
select {
case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
case resp := <-b.req.successCh:
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1b.req.errCh)
case err := <-scenario1b.req.errCh:
require.Empty(t, b.req.errCh)
case err := <-b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
}
func TestRequestsSimpleReloadSameModel(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
tmpModel := *a.req.model
b.req.model = &tmpModel
b.ggml = a.ggml
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-a.req.successCh:
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, a.req.errCh)
case err := <-a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
// Trigger a reload
s.newServerFn = scenario2a.newServer
scenario2a.req.model.AdapterPaths = []string{"new"}
slog.Info("scenario2a")
s.pendingReqCh <- scenario2a.req
s.newServerFn = b.newServer
b.req.model.AdapterPaths = []string{"new"}
slog.Info("b")
s.pendingReqCh <- b.req
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
scenario1a.ctxDone()
scenario1b.ctxDone()
a.ctxDone()
select {
case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv)
case resp := <-b.req.successCh:
require.Equal(t, resp.llama, b.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario2a.req.errCh)
case err := <-scenario2a.req.errCh:
require.Empty(t, b.req.errCh)
case err := <-b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
}
func TestRequestsMultipleLoadedModels(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
// Multiple loaded models
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
c.req.opts.NumGPU = 0 // CPU load, will be allowed
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
envconfig.MaxRunners = 1
s.newServerFn = scenario3a.newServer
slog.Info("scenario3a")
s.pendingReqCh <- scenario3a.req
// finish prior request, so new model can load
time.Sleep(1 * time.Millisecond)
scenario2a.ctxDone()
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
s.Run(ctx)
select {
case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv)
case resp := <-a.req.successCh:
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3a.req.errCh)
case err := <-scenario3a.req.errCh:
require.Empty(t, a.req.errCh)
case err := <-a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
@@ -262,15 +293,15 @@ func TestRequests(t *testing.T) {
s.loadedMu.Unlock()
envconfig.MaxRunners = 0
s.newServerFn = scenario3b.newServer
slog.Info("scenario3b")
s.pendingReqCh <- scenario3b.req
s.newServerFn = b.newServer
slog.Info("b")
s.pendingReqCh <- b.req
select {
case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv)
case resp := <-b.req.successCh:
require.Equal(t, resp.llama, b.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3b.req.errCh)
case err := <-scenario3b.req.errCh:
require.Empty(t, b.req.errCh)
case err := <-b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
@@ -280,15 +311,15 @@ func TestRequests(t *testing.T) {
s.loadedMu.Unlock()
// This is a CPU load with NumGPU = 0 so it should load
s.newServerFn = scenario3c.newServer
slog.Info("scenario3c")
s.pendingReqCh <- scenario3c.req
s.newServerFn = c.newServer
slog.Info("c")
s.pendingReqCh <- c.req
select {
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
case resp := <-c.req.successCh:
require.Equal(t, resp.llama, c.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3c.req.errCh)
case err := <-scenario3c.req.errCh:
require.Empty(t, c.req.errCh)
case err := <-c.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
@@ -298,25 +329,25 @@ func TestRequests(t *testing.T) {
s.loadedMu.Unlock()
// Try to load a model that wont fit
s.newServerFn = scenario3d.newServer
slog.Info("scenario3d")
s.newServerFn = d.newServer
slog.Info("d")
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3d.req
s.pendingReqCh <- d.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
scenario3b.ctxDone()
b.ctxDone()
select {
case resp := <-scenario3d.req.successCh:
require.Equal(t, resp.llama, scenario3d.srv)
case resp := <-d.req.successCh:
require.Equal(t, resp.llama, d.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3d.req.errCh)
require.Empty(t, d.req.errCh)
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -325,30 +356,59 @@ func TestRequests(t *testing.T) {
s.loadedMu.Unlock()
}
func TestRequestsModelTooBigForSystem(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 4 * format.MebiByte
g.FreeMemory = 3 * format.MebiByte
return []gpu.GpuInfo{g}
}
s.getCpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"}
g.TotalMemory = 4 * format.MebiByte
g.FreeMemory = 2 * format.MebiByte
return []gpu.GpuInfo{g}
}
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case <-a.req.successCh:
if runtime.GOOS == "linux" {
t.Fatal("request should have been rejected with out of space")
}
// else - Darwin and Windows don't reject right now
case err := <-a.req.errCh:
require.Contains(t, err.Error(), "too large")
case <-ctx.Done():
t.Fatal("timeout")
}
}
func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
s.newServerFn = a.newServer
slog.Info("a")
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
slog.Info("b")
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1)
@@ -357,22 +417,24 @@ func TestGetRunner(t *testing.T) {
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, errCh1a)
case err := <-errCh1a:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
scenario1a.ctxDone()
a.ctxDone() // Set "a" model to idle so it can unload
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
scenario1c.req.model.ModelPath = "bad path"
slog.Info("scenario1c")
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
c.req.model.ModelPath = "bad path"
slog.Info("c")
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
// Starts in pending channel, then should be quickly processsed to return an error
time.Sleep(5 * time.Millisecond)
time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload
require.Empty(t, successCh1c)
s.loadedMu.Lock()
require.Empty(t, s.loaded)
@@ -380,7 +442,7 @@ func TestGetRunner(t *testing.T) {
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
scenario1b.ctxDone()
b.ctxDone()
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count
@@ -389,7 +451,7 @@ func TestPrematureExpired(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
@@ -411,6 +473,8 @@ func TestPrematureExpired(t *testing.T) {
s.loadedMu.Unlock()
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case err := <-errCh1a:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -446,6 +510,8 @@ func TestUseLoadedRunner(t *testing.T) {
select {
case success := <-req.successCh:
require.Equal(t, r1, success)
case err := <-req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -625,8 +691,7 @@ func TestAlreadyCanceled(t *testing.T) {
defer done()
dctx, done2 := context.WithCancel(ctx)
done2()
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req

View File

@@ -46,7 +46,7 @@ Action: ```json
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
"parameters": {{ json .Function.Arguments }}
"parameters": {{ .Function.Arguments }}
}
{{- end }}
]```

View File

@@ -17,7 +17,7 @@ If you decide to call functions:
Available functions as JSON spec:
{{- if .Tools }}
{{ json .Tools }}
{{ .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
@@ -25,7 +25,7 @@ Available functions as JSON spec:
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

View File

@@ -0,0 +1,43 @@
{{- if .Messages }}
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{ .System }}
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools>
{{- range .Tools }} {{ .Function }}
{{- end }} </tools>
{{- end }}
{{- end }}<|eot_id|>
{{- range .Messages }}
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ if eq .Role "user" }}{{ .Content }}
{{- else if eq .Role "assistant" }}
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
</tool_call>
{{- end }}
{{- else if eq .Role "tool" }}<tool_response>
{{ .Content }}
</tool_response>
{{- end }}<|eot_id|>
{{- end }}
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}{{ .Response }}
{{- if .Response }}<|eot_id|>
{{- end }}

View File

@@ -0,0 +1,24 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<tool_call>
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
<tool_response>
22
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,13 +1,13 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]

View File

@@ -150,9 +150,9 @@ func (t *Template) Vars() []string {
type Values struct {
Messages []api.Message
Tools []api.Tool
Prompt string
Suffix string
api.Tools
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
@@ -217,6 +217,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
"System": system,
"Messages": messages,
"Tools": v.Tools,
"Response": "",
})
}
@@ -263,6 +264,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
cut = true
return false
}
return cut
@@ -270,8 +272,9 @@ func (t *Template) Execute(w io.Writer, v Values) error {
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"System": system,
"Prompt": prompt,
"Response": response,
}); err != nil {
return err
}

View File

@@ -260,6 +260,26 @@ func TestExecuteWithMessages(t *testing.T) {
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"mistral assistant",
[]template{
{"no response", `[INST] {{ .Prompt }}[/INST] `},
{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `
{{- range $i, $m := .Messages }}
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
{Role: "assistant", Content: "My name is Ollama and I"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
},
{
"chatml",
[]template{