Merge branch 'main' into jyan/reord-g

This commit is contained in:
Josh 2024-07-16 16:42:49 -07:00 committed by GitHub
commit d069cf753b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 2358 additions and 266 deletions

View File

@ -126,7 +126,7 @@ jobs:
strategy: strategy:
matrix: matrix:
rocm-version: rocm-version:
- '6.1.1' - '6.1.2'
runs-on: linux runs-on: linux
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }} container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps: steps:

View File

@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
ARG CMAKE_VERSION=3.22.1 ARG CMAKE_VERSION=3.22.1
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md # this CUDA_VERSION corresponds with the one specified in docs/gpu.md
ARG CUDA_VERSION=11.3.1 ARG CUDA_VERSION=11.3.1
ARG ROCM_VERSION=6.1.1 ARG ROCM_VERSION=6.1.2
# Copy the minimal context we need to run the generate scripts # Copy the minimal context we need to run the generate scripts
FROM scratch AS llm-code FROM scratch AS llm-code

View File

@ -293,6 +293,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama) - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
### Terminal ### Terminal

View File

@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
return nil return nil
} }
// Embeddings generates embeddings from a model. // Embed generates embeddings from a model.
func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
var resp EmbedResponse
if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Embeddings generates an embedding from a model.
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) { func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil { if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {

View File

@ -47,6 +47,9 @@ type GenerateRequest struct {
// Prompt is the textual prompt to send to the model. // Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
// Suffix is the text that comes after the inserted text.
Suffix string `json:"suffix"`
// System overrides the model's default system message/prompt. // System overrides the model's default system message/prompt.
System string `json:"system"` System string `json:"system"`
@ -97,6 +100,9 @@ type ChatRequest struct {
// followin the request. // followin the request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to.
Tools []Tool `json:"tools,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@ -106,8 +112,45 @@ type ChatRequest struct {
// of images. // of images.
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content,omitempty"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
type ToolCall struct {
Function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
} `json:"function"`
}
type Tool struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
} `json:"function"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message
var a Alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}
*m = Message(a)
m.Role = strings.ToLower(m.Role)
return nil
} }
// ChatResponse is the response returned by [Client.Chat]. Its fields are // ChatResponse is the response returned by [Client.Chat]. Its fields are
@ -173,6 +216,30 @@ type Runner struct {
NumThread int `json:"num_thread,omitempty"` NumThread int `json:"num_thread,omitempty"`
} }
// EmbedRequest is the request passed to [Client.Embed].
type EmbedRequest struct {
// Model is the model name.
Model string `json:"model"`
// Input is the input to embed.
Input any `json:"input"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings"`
}
// EmbeddingRequest is the request passed to [Client.Embeddings]. // EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct { type EmbeddingRequest struct {
// Model is the model name. // Model is the model name.
@ -221,6 +288,8 @@ type DeleteRequest struct {
type ShowRequest struct { type ShowRequest struct {
Model string `json:"model"` Model string `json:"model"`
System string `json:"system"` System string `json:"system"`
// Template is deprecated
Template string `json:"template"` Template string `json:"template"`
Verbose bool `json:"verbose"` Verbose bool `json:"verbose"`
@ -336,6 +405,9 @@ type GenerateResponse struct {
// Response is the textual response itself. // Response is the textual response itself.
Response string `json:"response"` Response string `json:"response"`
// ToolCalls is the list of tools the model wants to call
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// Done specifies if the response is complete. // Done specifies if the response is complete.
Done bool `json:"done"` Done bool `json:"done"`

View File

@ -208,3 +208,26 @@ func TestUseMmapFormatParams(t *testing.T) {
}) })
} }
} }
func TestMessage_UnmarshalJSON(t *testing.T) {
tests := []struct {
input string
expected string
}{
{`{"role": "USER", "content": "Hello!"}`, "user"},
{`{"role": "System", "content": "Initialization complete."}`, "system"},
{`{"role": "assistant", "content": "How can I help you?"}`, "assistant"},
{`{"role": "TOOl", "content": "Access granted."}`, "tool"},
}
for _, test := range tests {
var msg Message
if err := json.Unmarshal([]byte(test.input), &msg); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if msg.Role != test.expected {
t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected)
}
}
}

View File

@ -127,6 +127,10 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history" Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved ; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
[InstallDelete]
Type: filesandordirs; Name: "{%TEMP}\ollama*"
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
[Messages] [Messages]
WizardReady=Ollama Windows Preview WizardReady=Ollama Windows Preview
ReadyLabel1=%nLet's get you up and running with your own large language models. ReadyLabel1=%nLet's get you up and running with your own large language models.

View File

@ -843,7 +843,6 @@ type runOptions struct {
WordWrap bool WordWrap bool
Format string Format string
System string System string
Template string
Images []api.ImageData Images []api.ImageData
Options map[string]interface{} Options map[string]interface{}
MultiModal bool MultiModal bool
@ -1037,7 +1036,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
Images: opts.Images, Images: opts.Images,
Format: opts.Format, Format: opts.Format,
System: opts.System, System: opts.System,
Template: opts.Template,
Options: opts.Options, Options: opts.Options,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
} }

View File

@ -27,7 +27,6 @@ const (
MultilineNone MultilineState = iota MultilineNone MultilineState = iota
MultilinePrompt MultilinePrompt
MultilineSystem MultilineSystem
MultilineTemplate
) )
func loadModel(cmd *cobra.Command, opts *runOptions) error { func loadModel(cmd *cobra.Command, opts *runOptions) error {
@ -94,7 +93,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter") fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system message") fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set history Enable history") fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history") fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
@ -204,10 +202,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System}) opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset() sb.Reset()
case MultilineTemplate:
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
} }
multiline = MultilineNone multiline = MultilineNone
@ -326,17 +320,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", ")) fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]] opts.Options[args[2]] = fp[args[2]]
case "system", "template": case "system":
if len(args) < 3 { if len(args) < 3 {
usageSet() usageSet()
continue continue
} }
if args[1] == "system" {
multiline = MultilineSystem multiline = MultilineSystem
} else if args[1] == "template" {
multiline = MultilineTemplate
}
line := strings.Join(args[2:], " ") line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`) line, ok := strings.CutPrefix(line, `"""`)
@ -356,7 +346,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
continue continue
} }
if args[1] == "system" {
opts.System = sb.String() // for display in modelfile opts.System = sb.String() // for display in modelfile
newMessage := api.Message{Role: "system", Content: sb.String()} newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system' // Check if the slice is not empty and the last message is from 'system'
@ -368,11 +357,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset() sb.Reset()
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}
sb.Reset() sb.Reset()
continue continue
@ -393,7 +377,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
req := &api.ShowRequest{ req := &api.ShowRequest{
Name: opts.Model, Name: opts.Model,
System: opts.System, System: opts.System,
Template: opts.Template,
Options: opts.Options, Options: opts.Options,
} }
resp, err := client.Show(cmd.Context(), req) resp, err := client.Show(cmd.Context(), req)
@ -437,12 +420,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("No system message was specified for this model.") fmt.Println("No system message was specified for this model.")
} }
case "template": case "template":
switch { if resp.Template != "" {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template) fmt.Println(resp.Template)
default: } else {
fmt.Println("No prompt template was specified for this model.") fmt.Println("No prompt template was specified for this model.")
} }
default: default:
@ -536,10 +516,6 @@ func buildModelfile(opts runOptions) string {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System) fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
} }
if opts.Template != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}
keys := make([]string, 0) keys := make([]string, 0)
for k := range opts.Options { for k := range opts.Options {
keys = append(keys, k) keys = append(keys, k)

View File

@ -59,7 +59,6 @@ func TestModelfileBuilder(t *testing.T) {
opts := runOptions{ opts := runOptions{
Model: "hork", Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things", System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{ Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"}, {Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
@ -75,7 +74,6 @@ func TestModelfileBuilder(t *testing.T) {
mf := buildModelfile(opts) mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}} expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}""" SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop [hi there] PARAMETER stop [hi there]
@ -97,7 +95,6 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
mf = buildModelfile(opts) mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}} expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}""" SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop [hi there] PARAMETER stop [hi there]

152
integration/embed_test.go Normal file
View File

@ -0,0 +1,152 @@
//go:build integration
package integration
import (
"context"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if res.Embeddings[0][0] != 0.010071031 {
t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
}
}
func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"},
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 2 {
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
}
}
func TestAllMiniLmEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
truncTrue, truncFalse := true, false
type testReq struct {
Name string
Request api.EmbedRequest
}
reqs := []testReq{
{
Name: "Target Truncation",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why",
},
},
{
Name: "Default Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 1},
},
},
{
Name: "Explicit Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
},
}
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := embedTestHelper(ctx, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})
if err == nil {
t.Fatal("expected error, got nil")
}
}
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
response, err := client.Embed(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
}

View File

@ -3188,26 +3188,33 @@ int main(int argc, char **argv) {
prompt = ""; prompt = "";
} }
json image_data; if (prompt.size() == 1) {
if (body.count("image_data") != 0) { prompt = prompt[0];
image_data = body["image_data"];
}
else
{
image_data = "";
} }
// create and queue the task // create and queue the task
const int task_id = llama.queue_tasks.get_new_id(); json responses;
llama.queue_results.add_waiting_task_id(task_id); {
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1); const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
// get the result // get the result
task_result result = llama.queue_results.recv(task_id); task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(task_id); llama.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}
responses = result.result_json.value("results", std::vector<json>{result.result_json});
json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json embedding_res = json{{"embedding", embeddings}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
}); });
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

View File

@ -557,6 +557,7 @@ var ggufKVOrder = map[string][]string{
"tokenizer.ggml.add_bos_token", "tokenizer.ggml.add_bos_token",
"tokenizer.ggml.add_eos_token", "tokenizer.ggml.add_eos_token",
"tokenizer.chat_template", "tokenizer.chat_template",
"bert.pooling_type",
}, },
} }

View File

@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error) Embed(ctx context.Context, input []string) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
@ -127,7 +127,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
// On linux, over-allocating CPU memory will almost always result in an error // On linux, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := min(systemTotalMemory, systemFreeMemory+systemSwapFreeMemory) available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available { if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory)) slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available)) return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
@ -867,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil return nil
} }
type EmbeddingRequest struct { type EmbedRequest struct {
Content string `json:"content"` Content []string `json:"content"`
} }
type EmbeddingResponse struct { type EmbedResponse struct {
Embedding []float64 `json:"embedding"` Embedding [][]float32 `json:"embedding"`
} }
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) { func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil { if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err) slog.Error("Failed to acquire semaphore", "error", err)
return nil, err return nil, err
@ -890,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(TokenizeRequest{Content: prompt}) data, err := json.Marshal(EmbedRequest{Content: input})
if err != nil { if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err) return nil, fmt.Errorf("error marshaling embed data: %w", err)
} }
@ -917,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("%s", body) return nil, fmt.Errorf("%s", body)
} }
var embedding EmbeddingResponse var embedding EmbedResponse
if err := json.Unmarshal(body, &embedding); err != nil { if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err) return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
} }

View File

@ -3,11 +3,13 @@ package openai
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -28,7 +30,7 @@ type ErrorResponse struct {
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content any `json:"content"`
} }
type Choice struct { type Choice struct {
@ -59,6 +61,11 @@ type ResponseFormat struct {
Type string `json:"type"` Type string `json:"type"`
} }
type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
}
type ChatCompletionRequest struct { type ChatCompletionRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
@ -132,11 +139,23 @@ type Model struct {
OwnedBy string `json:"owned_by"` OwnedBy string `json:"owned_by"`
} }
type Embedding struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
type ListCompletion struct { type ListCompletion struct {
Object string `json:"object"` Object string `json:"object"`
Data []Model `json:"data"` Data []Model `json:"data"`
} }
type EmbeddingList struct {
Object string `json:"object"`
Data []Embedding `json:"data"`
Model string `json:"model"`
}
func NewError(code int, message string) ErrorResponse { func NewError(code int, message string) ErrorResponse {
var etype string var etype string
switch code { switch code {
@ -260,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
} }
} }
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil {
var data []Embedding
for i, e := range r.Embeddings {
data = append(data, Embedding{
Object: "embedding",
Embedding: e,
Index: i,
})
}
return EmbeddingList{
Object: "list",
Data: data,
Model: model,
}
}
return EmbeddingList{}
}
func toModel(r api.ShowResponse, m string) Model { func toModel(r api.ShowResponse, m string) Model {
return Model{ return Model{
Id: m, Id: m,
@ -269,10 +309,66 @@ func toModel(r api.ShowResponse, m string) Model {
} }
} }
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var messages []api.Message var messages []api.Message
for _, msg := range r.Messages { for _, msg := range r.Messages {
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content}) switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: msg.Role, Content: content})
case []any:
message := api.Message{Role: msg.Role}
for _, c := range content {
data, ok := c.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
switch data["type"] {
case "text":
text, ok := data["text"].(string)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
message.Content = text
case "image_url":
var url string
if urlMap, ok := data["image_url"].(map[string]any); ok {
if url, ok = urlMap["url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
} else {
if url, ok = data["image_url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
}
types := []string{"jpeg", "jpg", "png"}
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, fmt.Errorf("invalid image input")
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, fmt.Errorf("invalid message format")
}
message.Images = append(message.Images, img)
default:
return nil, fmt.Errorf("invalid message format")
}
}
messages = append(messages, message)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
} }
options := make(map[string]interface{}) options := make(map[string]interface{})
@ -323,13 +419,13 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
format = "json" format = "json"
} }
return api.ChatRequest{ return &api.ChatRequest{
Model: r.Model, Model: r.Model,
Messages: messages, Messages: messages,
Format: format, Format: format,
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
} }, nil
} }
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
@ -407,6 +503,11 @@ type RetrieveWriter struct {
model string model string
} }
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(code int, data []byte) (int, error) { func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
var serr api.StatusError var serr api.StatusError
err := json.Unmarshal(data, &serr) err := json.Unmarshal(data, &serr)
@ -572,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
return w.writeResponse(data) return w.writeResponse(data)
} }
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc { func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
w := &ListWriter{ w := &ListWriter{
@ -635,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
} }
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w c.Writer = w
c.Next() c.Next()
@ -656,7 +825,13 @@ func ChatMiddleware() gin.HandlerFunc {
} }
var b bytes.Buffer var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return return
} }

View File

@ -2,6 +2,7 @@ package openai
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -15,6 +16,10 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image
func TestMiddlewareRequests(t *testing.T) { func TestMiddlewareRequests(t *testing.T) {
type testCase struct { type testCase struct {
Name string Name string
@ -112,6 +117,122 @@ func TestMiddlewareRequests(t *testing.T) {
} }
}, },
}, },
{
Name: "chat handler with image content",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{
Role: "user", Content: []map[string]any{
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
},
},
},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
}
},
},
{
Name: "embed handler single input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
}
if embedReq.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", embedReq.Input)
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
}
},
},
{
Name: "embed handler batch input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
}
input, ok := embedReq.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
}
},
},
} }
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@ -34,11 +34,20 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
var errCapabilityCompletion = errors.New("completion") var (
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
)
type Capability string type Capability string
const CapabilityCompletion = Capability("completion") const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
)
type registryOptions struct { type registryOptions struct {
Insecure bool Insecure bool
@ -88,6 +97,15 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
errs = append(errs, errCapabilityCompletion) errs = append(errs, errCapabilityCompletion)
} }
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errCapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
}
default: default:
slog.Error("unknown capability", "capability", cap) slog.Error("unknown capability", "capability", cap)
return fmt.Errorf("unknown capability: %s", cap) return fmt.Errorf("unknown capability: %s", cap)
@ -95,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
} }
if err := errors.Join(errs...); err != nil { if err := errors.Join(errs...); err != nil {
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...)) return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
} }
return nil return nil

View File

@ -4,6 +4,7 @@ import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -12,6 +13,9 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
"slices"
"strings"
"text/template/parse"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert" "github.com/ollama/ollama/convert"
@ -322,3 +326,87 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil return "unknown", nil
} }
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]map[string]any{
"ToolCalls": {
{
"Function": map[string]any{
"Name": "@@name@@",
"Arguments": "@@arguments@@",
},
},
},
}); err != nil {
return nil, false
}
var kv map[string]string
// execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v {
case "@@name@@":
name = k
case "@@arguments@@":
arguments = k
}
}
var objs []map[string]any
for offset := 0; offset < len(s); {
if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
return nil, false
} else {
// break when an object is decoded
break
}
}
var toolCalls []api.ToolCall
for _, kv := range objs {
var call api.ToolCall
for k, v := range kv {
switch k {
case name:
call.Function.Name = v.(string)
case arguments:
call.Function.Arguments = v.(map[string]any)
}
}
toolCalls = append(toolCalls, call)
}
return toolCalls, len(toolCalls) > 0
}

View File

@ -3,7 +3,9 @@ package server
import ( import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
@ -11,7 +13,9 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
func createZipFile(t *testing.T, name string) *os.File { func createZipFile(t *testing.T, name string) *os.File {
@ -110,3 +114,123 @@ func TestExtractFromZipFile(t *testing.T) {
}) })
} }
} }
type function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}
return bytes.NewBuffer(bts)
}
func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
model string
output string
ok bool
}{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```", true},
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
}
var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
calls := []api.ToolCall{
{
Function: function{
Name: "get_current_weather",
Arguments: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: function{
Name: "get_current_weather",
Arguments: map[string]any{
"format": "celsius",
"location": "Toronto, Canada",
},
},
},
}
for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
if ok != tt.ok {
t.Fatalf("expected %t, got %t", tt.ok, ok)
}
if tt.ok {
if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
})
})
}
}

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"log/slog" "log/slog"
"slices"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
@ -16,29 +15,21 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages // latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) { func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
// pull out any system messages which should always be included in the prompt
var system []api.Message var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
system = append(system, m)
return true
}
return false
})
if len(system) == 0 && m.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: m.System})
}
// always include the last message // always include the last message
n := len(msgs) - 1 n := len(msgs) - 1
// in reverse, find all messages that fit into context window // in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- { for i := n - 1; i >= 0; i-- {
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
return "", nil, err return "", nil, err
} }
@ -66,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
return "", nil, err return "", nil, err
} }

View File

@ -3,21 +3,13 @@ package server
import ( import (
"bytes" "bytes"
"context" "context"
"strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
func tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
return
}
func TestChatPrompt(t *testing.T) { func TestChatPrompt(t *testing.T) {
type expect struct { type expect struct {
prompt string prompt string
@ -164,6 +156,19 @@ func TestChatPrompt(t *testing.T) {
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ", prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
}, },
}, },
{
name: "out of order system",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
},
} }
tmpl, err := template.Parse(` tmpl, err := template.Parse(`
@ -178,13 +183,13 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs) prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tt.prompt != prompt { if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
t.Errorf("expected %q, got %q", tt.prompt, prompt) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
if len(images) != len(tt.images) { if len(images) != len(tt.images) {

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@ -102,6 +103,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@ -120,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
if req.Suffix != "" {
caps = append(caps, CapabilityInsert)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
@ -129,6 +135,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
checkpointLoaded := time.Now()
if req.Prompt == "" { if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{ c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model, Model: req.Model,
@ -146,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt := req.Prompt prompt := req.Prompt
if !req.Raw { if !req.Raw {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
tmpl := m.Template tmpl := m.Template
if req.Template != "" { if req.Template != "" {
tmpl, err = template.Parse(req.Template) tmpl, err = template.Parse(req.Template)
@ -179,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
b.WriteString(s) b.WriteString(s)
} }
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil { var values template.Values
if req.Suffix != "" {
values.Prompt = prompt
values.Suffix = req.Suffix
} else {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
}
if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@ -191,26 +205,48 @@ func (s *Server) GenerateHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch) defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
}, func(r llm.CompletionResponse) { }, func(cr llm.CompletionResponse) {
ch <- api.GenerateResponse{ res := api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Response: r.Content, Response: cr.Content,
Done: r.Done, Done: cr.Done,
DoneReason: r.DoneReason, DoneReason: cr.DoneReason,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: cr.EvalCount,
EvalDuration: r.EvalDuration, EvalDuration: cr.EvalDuration,
}, },
} }
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
}
if cr.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Context = append(req.Context, tokens...)
}
}
ch <- res
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
@ -239,6 +275,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
r.Response = sb.String() r.Response = sb.String()
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
r.ToolCalls = toolCalls
r.Response = ""
}
c.JSON(http.StatusOK, r) c.JSON(http.StatusOK, r)
return return
} }
@ -246,6 +287,121 @@ func (s *Server) GenerateHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func (s *Server) EmbedHandler(c *gin.Context) {
var req api.EmbedRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
var input []string
switch i := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
}
case []any:
for _, v := range i {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
kvData, err := getKVData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input)
if err != nil {
slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
for i, e := range embeddings {
embeddings[i] = normalize(e)
}
resp := api.EmbedResponse{
Model: req.Model,
Embeddings: embeddings,
}
c.JSON(http.StatusOK, resp)
}
func normalize(vec []float32) []float32 {
var sum float32
for _, v := range vec {
sum += v * v
}
norm := float32(0.0)
if sum > 0 {
norm = float32(1.0 / math.Sqrt(float64(sum)))
}
for i := range vec {
vec[i] *= norm
}
return vec
}
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest var req api.EmbeddingRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@ -268,14 +424,24 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embedding, err := r.Embedding(c.Request.Context(), req.Prompt) embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return return
} }
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding}) embedding := make([]float64, len(embeddings[0]))
for i, v := range embeddings[0] {
embedding[i] = float64(v)
}
resp := api.EmbeddingResponse{
Embedding: embedding,
}
c.JSON(http.StatusOK, resp)
} }
func (s *Server) PullModelHandler(c *gin.Context) { func (s *Server) PullModelHandler(c *gin.Context) {
@ -549,13 +715,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m.System = req.System m.System = req.System
} }
if req.Template != "" {
m.Template, err = template.Parse(req.Template)
if err != nil {
return nil, err
}
}
msgs := make([]api.Message, len(m.Messages)) msgs := make([]api.Message, len(m.Messages))
for i, msg := range m.Messages { for i, msg := range m.Messages {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content} msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
@ -901,6 +1060,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/pull", s.PullModelHandler) r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", s.GenerateHandler) r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler) r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateModelHandler) r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler) r.POST("/api/push", s.PushModelHandler)
@ -914,6 +1074,7 @@ func (s *Server) GenerateRoutes() http.Handler {
// Compatibility endpoints // Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
@ -1122,6 +1283,8 @@ func (s *Server) ProcessHandler(c *gin.Context) {
} }
func (s *Server) ChatHandler(c *gin.Context) { func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest var req api.ChatRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@ -1132,6 +1295,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
if req.Tools != nil {
caps = append(caps, CapabilityTools)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@ -1141,6 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
checkpointLoaded := time.Now()
if len(req.Messages) == 0 { if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{ c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model, Model: req.Model,
@ -1152,7 +1321,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages) if req.Messages[0].Role != "system" && m.System != "" {
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -1169,7 +1342,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
ch <- api.ChatResponse{ res := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content}, Message: api.Message{Role: "assistant", Content: r.Content},
@ -1182,19 +1355,26 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration, EvalDuration: r.EvalDuration,
}, },
} }
if r.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- res
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if req.Stream != nil && !*req.Stream {
var r api.ChatResponse var resp api.ChatResponse
var sb strings.Builder var sb strings.Builder
for rr := range ch { for rr := range ch {
switch t := rr.(type) { switch t := rr.(type) {
case api.ChatResponse: case api.ChatResponse:
sb.WriteString(t.Message.Content) sb.WriteString(t.Message.Content)
r = t resp = t
case gin.H: case gin.H:
msg, ok := t["error"].(string) msg, ok := t["error"].(string)
if !ok { if !ok {
@ -1209,8 +1389,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
} }
r.Message.Content = sb.String() resp.Message.Content = sb.String()
c.JSON(http.StatusOK, r) if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
c.JSON(http.StatusOK, resp)
return return
} }
@ -1219,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
func handleScheduleError(c *gin.Context, name string, err error) { func handleScheduleError(c *gin.Context, name string, err error) {
switch { switch {
case errors.Is(err, errRequired): case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"}) c.JSON(499, gin.H{"error": "request canceled"})

View File

@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) {
} }
func TestCreateFromBin(t *testing.T) { func TestCreateFromBin(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) {
} }
func TestCreateFromModel(t *testing.T) { func TestCreateFromModel(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) {
} }
func TestCreateRemovesLayers(t *testing.T) { func TestCreateRemovesLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) {
} }
func TestCreateUnsetsSystem(t *testing.T) { func TestCreateUnsetsSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
} }
func TestCreateMergeParameters(t *testing.T) { func TestCreateMergeParameters(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) {
} }
func TestCreateReplacesMessages(t *testing.T) { func TestCreateReplacesMessages(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) {
} }
func TestCreateTemplateSystem(t *testing.T) { func TestCreateTemplateSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@ -480,6 +494,8 @@ func TestCreateTemplateSystem(t *testing.T) {
} }
func TestCreateLicenses(t *testing.T) { func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@ -526,6 +542,8 @@ func TestCreateLicenses(t *testing.T) {
} }
func TestCreateDetectTemplate(t *testing.T) { func TestCreateDetectTemplate(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()

View File

@ -8,12 +8,15 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@ -77,6 +80,8 @@ func TestDelete(t *testing.T) {
} }
func TestDeleteDuplicateLayers(t *testing.T) { func TestDeleteDuplicateLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server

View File

@ -0,0 +1,712 @@
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
)
type mockRunner struct {
llm.LlamaServer
// CompletionRequest is only valid until the next call to Completion
llm.CompletionRequest
llm.CompletionResponse
}
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
m.CompletionRequest = r
fn(m.CompletionResponse)
return nil
}
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
return
}
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return mock, nil
}
}
func TestGenerateChat(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add 10ms delay to simulate loading
time.Sleep(10 * time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities chat", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.ChatResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if diff := cmp.Diff(actual.Message, api.Message{
Role: "assistant",
Content: content,
}); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("messages", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("messages with model system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("messages with system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("messages with interleaved system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "assistant", Content: "I can help you with that."},
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Help me write tests."},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
}
func TestGenerate(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities generate", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.GenerateResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.GenerateResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if actual.Response != content {
t.Errorf("expected response %s, got %s", content, actual.Response)
}
if actual.Context == nil {
t.Errorf("expected context not nil")
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("prompt", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with model system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("prompt with system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
System: "You can perform magic tricks.",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("prompt with template", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
System: "You can perform magic tricks.",
Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-suffix",
Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("prompt without suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("raw", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
Raw: true,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}

View File

@ -7,11 +7,14 @@ import (
"slices" "slices"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
) )
func TestList(t *testing.T) { func TestList(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())
envconfig.LoadConfig() envconfig.LoadConfig()

View File

@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@ -272,6 +273,77 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "library", retrieveResp.OwnedBy) assert.Equal(t, "library", retrieveResp.OwnedBy)
}, },
}, },
{
Name: "Embed Handler Empty Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: "",
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
var embedResp api.EmbedResponse
err = json.Unmarshal(body, &embedResp)
if err != nil {
t.Fatal(err)
}
if embedResp.Model != "t-bone" {
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
}
if embedResp.Embeddings == nil {
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
}
if len(embedResp.Embeddings) != 0 {
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
}
},
},
{
Name: "Embed Handler Invalid Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: 2,
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
_, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
}
},
},
} }
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())
@ -420,3 +492,38 @@ func TestShow(t *testing.T) {
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"]) t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
} }
} }
func TestNormalize(t *testing.T) {
type testCase struct {
input []float32
}
testCases := []testCase{
{input: []float32{1}},
{input: []float32{0, 1, 2, 3}},
{input: []float32{0.1, 0.2, 0.3}},
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
{input: []float32{0, 0, 0}},
}
isNormalized := func(vec []float32) (res bool) {
sum := 0.0
for _, v := range vec {
sum += float64(v * v)
}
if math.Abs(sum-1) > 1e-6 {
return sum == 0
} else {
return true
}
}
for _, tc := range testCases {
t.Run("", func(t *testing.T) {
normalized := normalize(tc.input)
if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
}
})
}
}

View File

@ -642,8 +642,8 @@ type mockLlm struct {
pingResp error pingResp error
waitResp error waitResp error
completionResp error completionResp error
embeddingResp []float64 embedResp [][]float32
embeddingRespErr error embedRespErr error
tokenizeResp []int tokenizeResp []int
tokenizeRespErr error tokenizeRespErr error
detokenizeResp string detokenizeResp string
@ -660,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp return s.completionResp
} }
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) { func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
return s.embeddingResp, s.embeddingRespErr return s.embedResp, s.embedRespErr
} }
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr return s.tokenizeResp, s.tokenizeRespErr

View File

@ -0,0 +1,67 @@
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
{{- if .Tools }}# Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
{{ if .System }}# User Preamble
{{ .System }}
{{- end }}
## Available Tools
Here is a list of tools that you have available to you:
{{- range .Tools }}
```python
def {{ .Function.Name }}(
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
"""{{ .Function.Description }}
{{- if .Function.Parameters.Properties }}
Args:
{{- range $name, $property := .Function.Parameters.Properties }}
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
{{- end }}
{{- end }}
"""
pass
```
{{- end }}
{{- else if .System }}{{ .System }}
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- range .Messages }}
{{- if eq .Role "system" }}
{{- continue }}
{{- end }}<|START_OF_TURN_TOKEN|>
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}
Action: ```json
[
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
"parameters": {{ json .Function.Arguments }}
}
{{- end }}
]```
{{ continue }}
{{ end }}
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
{{ .Content }}</results>
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@ -0,0 +1,39 @@
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
# User Preamble
You are a knowledgable assistant. You can answer questions and perform tasks.
## Available Tools
Here is a list of tools that you have available to you:
```python
def get_current_weather(format: string, location: string, ) -> List[Dict]:
"""Get the current weather
Args:
format (string): The temperature unit to use. Infer this from the users location.
location (string): The city and state, e.g. San Francisco, CA
"""
pass
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
Action: ```json
[
{
"tool_name": "get_current_weather",
"parameters": {"format":"celsius","location":"Paris, France"}
}
]```
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@ -0,0 +1,31 @@
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{- if .System }}
{{ .System }}
{{- end }}
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
{{- if .Tools }}
{{ json .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

17
server/testdata/tools/firefunction.out vendored Normal file
View File

@ -0,0 +1,17 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks.
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

39
server/testdata/tools/messages.json vendored Normal file
View File

@ -0,0 +1,39 @@
[
{
"role": "system",
"content": "You are a knowledgable assistant. You can answer questions and perform tasks."
},
{
"role": "user",
"content": "What's the weather like today in Paris?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": {
"location": "Paris, France",
"format": "celsius"
}
}
}
]
},
{
"role": "tool",
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"content": "22"
},
{
"role": "assistant",
"content": "The current temperature in Paris, France is 22 degrees Celsius."
},
{
"role": "user",
"content": "What's the weather like today in San Francisco and Toronto?"
}
]

15
server/testdata/tools/mistral.gotmpl vendored Normal file
View File

@ -0,0 +1,15 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}

3
server/testdata/tools/mistral.out vendored Normal file
View File

@ -0,0 +1,3 @@
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks.
What's the weather like today in San Francisco and Toronto?[/INST]

30
server/testdata/tools/tools.json vendored Normal file
View File

@ -0,0 +1,30 @@
[
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": [
"location",
"format"
]
}
}
}
]

View File

@ -103,16 +103,9 @@ var response = parse.ActionNode{
} }
var funcs = template.FuncMap{ var funcs = template.FuncMap{
// contents returns the contents of messages with an optional role filter "json": func(v any) string {
"contents": func(v []*api.Message, role ...string) string { b, _ := json.Marshal(v)
var parts []string return string(b)
for _, m := range v {
if len(role) == 0 || role[0] == "" || m.Role == role[0] {
parts = append(parts, m.Content)
}
}
return strings.Join(parts, "\n\n")
}, },
} }
@ -141,7 +134,7 @@ func (t *Template) Vars() []string {
var vars []string var vars []string
for _, tt := range t.Templates() { for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes { for _, n := range tt.Root.Nodes {
vars = append(vars, parseNode(n)...) vars = append(vars, Identifiers(n)...)
} }
} }
@ -157,32 +150,81 @@ func (t *Template) Vars() []string {
type Values struct { type Values struct {
Messages []api.Message Messages []api.Message
Tools []api.Tool
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates // forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool forceLegacy bool
} }
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
var walk func(parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return n
}
switch t := n.(type) {
case *parse.ListNode:
for _, c := range t.Nodes {
if n := walk(c); n != nil {
return n
}
}
case *parse.BranchNode:
for _, n := range []*parse.ListNode{t.List, t.ElseList} {
if n != nil {
if n := walk(n); n != nil {
return n
}
}
}
case *parse.IfNode:
return walk(&t.BranchNode)
case *parse.WithNode:
return walk(&t.BranchNode)
case *parse.RangeNode:
return walk(&t.BranchNode)
}
return nil
}
if n := walk(t.Tree.Root); n != nil {
return (&template.Template{
Tree: &parse.Tree{
Root: &parse.ListNode{
Nodes: []parse.Node{n},
},
},
}).Funcs(funcs)
}
return nil
}
func (t *Template) Execute(w io.Writer, v Values) error { func (t *Template) Execute(w io.Writer, v Values) error {
collated := collate(v.Messages) system, messages := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
"Messages": collated, "Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Tools": v.Tools,
}) })
} }
system = ""
var b bytes.Buffer var b bytes.Buffer
var system, prompt, response string var prompt, response string
for i, m := range collated { for _, m := range messages {
switch m.Role { execute := func() error {
case "system":
system = m.Content
case "user":
prompt = m.Content
case "assistant":
response = m.Content
}
if i != len(collated)-1 && prompt != "" && response != "" {
if err := t.Template.Execute(&b, map[string]any{ if err := t.Template.Execute(&b, map[string]any{
"System": system, "System": system,
"Prompt": prompt, "Prompt": prompt,
@ -194,25 +236,41 @@ func (t *Template) Execute(w io.Writer, v Values) error {
system = "" system = ""
prompt = "" prompt = ""
response = "" response = ""
return nil
}
switch m.Role {
case "system":
if prompt != "" || response != "" {
if err := execute(); err != nil {
return err
}
}
system = m.Content
case "user":
if response != "" {
if err := execute(); err != nil {
return err
}
}
prompt = m.Content
case "assistant":
response = m.Content
} }
} }
var cut bool var cut bool
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool { nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
switch t := n.(type) { if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
case *parse.ActionNode:
case *parse.FieldNode:
if slices.Contains(t.Ident, "Response") {
cut = true cut = true
} }
}
return cut return cut
}) })
tree := parse.Tree{Root: nodes.(*parse.ListNode)} tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{ if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": "", "System": system,
"Prompt": prompt, "Prompt": prompt,
}); err != nil { }); err != nil {
return err return err
@ -223,11 +281,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
} }
// collate messages based on role. consecutive messages of the same role are merged // collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system" // into a single message. collate also collects and returns all system messages.
// which are templated separately. As a side effect, it mangles message content adding image // collate mutates message content adding image tags ([img-%d]) as needed
// tags ([img-%d]) as needed func collate(msgs []api.Message) (string, []*api.Message) {
func collate(msgs []api.Message) (collated []*api.Message) {
var n int var n int
var system []string
var collated []*api.Message
for i := range msgs { for i := range msgs {
msg := msgs[i] msg := msgs[i]
for range msg.Images { for range msg.Images {
@ -240,6 +300,10 @@ func collate(msgs []api.Message) (collated []*api.Message) {
n++ n++
} }
if msg.Role == "system" {
system = append(system, msg.Content)
}
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role { if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content collated[len(collated)-1].Content += "\n\n" + msg.Content
} else { } else {
@ -247,53 +311,49 @@ func collate(msgs []api.Message) (collated []*api.Message) {
} }
} }
return return strings.Join(system, "\n\n"), collated
} }
func parseNode(n parse.Node) []string { // Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) []string {
switch n := n.(type) { switch n := n.(type) {
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, Identifiers(n)...)
}
return names
case *parse.TemplateNode:
return Identifiers(n.Pipe)
case *parse.ActionNode: case *parse.ActionNode:
return parseNode(n.Pipe) return Identifiers(n.Pipe)
case *parse.BranchNode:
names := Identifiers(n.Pipe)
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
if n != nil {
names = append(names, Identifiers(n)...)
}
}
return names
case *parse.IfNode: case *parse.IfNode:
names := parseNode(n.Pipe) return Identifiers(&n.BranchNode)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.RangeNode: case *parse.RangeNode:
names := parseNode(n.Pipe) return Identifiers(&n.BranchNode)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.WithNode: case *parse.WithNode:
names := parseNode(n.Pipe) return Identifiers(&n.BranchNode)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.PipeNode: case *parse.PipeNode:
var names []string var names []string
for _, c := range n.Cmds { for _, c := range n.Cmds {
for _, a := range c.Args { for _, a := range c.Args {
names = append(names, parseNode(a)...) names = append(names, Identifiers(a)...)
} }
} }
return names
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, parseNode(n)...)
}
return names return names
case *parse.FieldNode: case *parse.FieldNode:
return n.Ident return n.Ident
case *parse.TemplateNode: case *parse.VariableNode:
return parseNode(n.Pipe) return n.Ident
} }
return nil return nil

View File

@ -216,13 +216,11 @@ func TestExecuteWithMessages(t *testing.T) {
{"response", `[INST] {{ if .System }}{{ .System }} {"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- $system := contents .Messages "system" -}} {"messages", `[INST] {{ if .System }}{{ .System }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
{{- $system = "" }}
{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{ end }}
{{- end }} {{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`}, {{- end }}`},
}, },
Values{ Values{
@ -243,13 +241,11 @@ func TestExecuteWithMessages(t *testing.T) {
{"response", `[INST] {{ if .System }}{{ .System }} {"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `{{- $system := contents .Messages "system" -}} {"messages", `[INST] {{ if .System }}{{ .System }}
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }}
{{- $system = "" }}
{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} {{ end }}
{{- end }} {{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`}, {{- end }}`},
}, },
Values{ Values{
@ -364,35 +360,37 @@ Answer: `,
} }
} }
func TestFuncs(t *testing.T) { func TestExecuteWithSuffix(t *testing.T) {
t.Run("contents", func(t *testing.T) { tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
cases := map[string]string{ {{- else }}{{ .Prompt }}
"": "A\n\nB\n\nC\n\nD\n\nE\n\nF", {{- end }}`)
"system": "A\n\nF", if err != nil {
"user": "B\n\nE", t.Fatal(err)
"assistant": "C\n\nD",
} }
s := []*api.Message{ cases := []struct {
{Role: "system", Content: "A"}, name string
{Role: "user", Content: "B"}, values Values
{Role: "assistant", Content: "C"}, expect string
{Role: "assistant", Content: "D"}, }{
{Role: "user", Content: "E"}, {
{Role: "system", Content: "F"}, "message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
},
{
"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
},
} }
fn, ok := funcs["contents"].(func([]*api.Message, ...string) string) for _, tt := range cases {
if !ok { t.Run(tt.name, func(t *testing.T) {
t.Fatal("contents is not a function") var b bytes.Buffer
if err := tmpl.Execute(&b, tt.values); err != nil {
t.Fatal(err)
} }
for k, v := range cases { if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
t.Run(k, func(t *testing.T) {
if diff := cmp.Diff(fn(s, k), v); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}) })
} }
})
} }