From 179737feb7311fc57c507a93378a3ac15da3a346 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Thu, 11 Jul 2024 22:53:46 -0700 Subject: [PATCH 01/25] Clean up old files when installing on Windows (#5645) * app: always clean up install dir; force close applications * remove wildcard * revert `CloseApplications` * whitespace * update `LOCALAPPDATA` var --- app/ollama.iss | 3 +++ 1 file changed, 3 insertions(+) diff --git a/app/ollama.iss b/app/ollama.iss index e6502abd3..fef4a7b25 100644 --- a/app/ollama.iss +++ b/app/ollama.iss @@ -127,6 +127,9 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models" Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history" ; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved +[InstallDelete] +Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama" + [Messages] WizardReady=Ollama Windows Preview ReadyLabel1=%nLet's get you up and running with your own large language models. From 36c87c433b7d880ef8b3a2b05ef93b0cd1675520 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 12 Jul 2024 11:48:06 -0700 Subject: [PATCH 02/25] template: preprocess message and collect system --- template/template.go | 37 +++++++++++---------------- template/template_test.go | 53 ++++++--------------------------------- 2 files changed, 23 insertions(+), 67 deletions(-) diff --git a/template/template.go b/template/template.go index 21e1614d0..9b3516665 100644 --- a/template/template.go +++ b/template/template.go @@ -102,22 +102,8 @@ var response = parse.ActionNode{ }, } -var funcs = template.FuncMap{ - // contents returns the contents of messages with an optional role filter - "contents": func(v []*api.Message, role ...string) string { - var parts []string - for _, m := range v { - if len(role) == 0 || role[0] == "" || m.Role == role[0] { - parts = append(parts, m.Content) - } - } - - return strings.Join(parts, "\n\n") - }, -} - func Parse(s string) (*Template, error) { - tmpl := template.New("").Option("missingkey=zero").Funcs(funcs) + tmpl := template.New("").Option("missingkey=zero") tmpl, err := tmpl.Parse(s) if err != nil { @@ -163,15 +149,16 @@ type Values struct { } func (t *Template) Execute(w io.Writer, v Values) error { - collated := collate(v.Messages) + system, collated := collate(v.Messages) if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { return t.Template.Execute(w, map[string]any{ + "System": system, "Messages": collated, }) } var b bytes.Buffer - var system, prompt, response string + var prompt, response string for i, m := range collated { switch m.Role { case "system": @@ -223,11 +210,13 @@ func (t *Template) Execute(w io.Writer, v Values) error { } // collate messages based on role. consecutive messages of the same role are merged -// into a single message. collate also pulls out and merges messages with Role == "system" -// which are templated separately. As a side effect, it mangles message content adding image -// tags ([img-%d]) as needed -func collate(msgs []api.Message) (collated []*api.Message) { +// into a single message. collate also collects and returns all system messages. +// collate mutates message content adding image tags ([img-%d]) as needed +func collate(msgs []api.Message) (string, []*api.Message) { var n int + + var system []string + var collated []*api.Message for i := range msgs { msg := msgs[i] for range msg.Images { @@ -240,6 +229,10 @@ func collate(msgs []api.Message) (collated []*api.Message) { n++ } + if msg.Role == "system" { + system = append(system, msg.Content) + } + if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role { collated[len(collated)-1].Content += "\n\n" + msg.Content } else { @@ -247,7 +240,7 @@ func collate(msgs []api.Message) (collated []*api.Message) { } } - return + return strings.Join(system, "\n\n"), collated } func parseNode(n parse.Node) []string { diff --git a/template/template_test.go b/template/template_test.go index 5e5f42570..c678f1b12 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -216,13 +216,11 @@ func TestExecuteWithMessages(t *testing.T) { {"response", `[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, - {"messages", `{{- $system := contents .Messages "system" -}} -{{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }} -{{- $system = "" }} + {"messages", `[INST] {{ if .System }}{{ .System }} -{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} -{{- end }} +{{ end }} +{{- range .Messages }} +{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }} {{- end }}`}, }, Values{ @@ -243,13 +241,11 @@ func TestExecuteWithMessages(t *testing.T) { {"response", `[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}[/INST] {{ .Response }}`}, - {"messages", `{{- $system := contents .Messages "system" -}} -{{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }}[INST] {{ if $system }}{{ $system }} -{{- $system = "" }} + {"messages", `[INST] {{ if .System }}{{ .System }} -{{ end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }} -{{- end }} +{{ end }} +{{- range .Messages }} +{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }} {{- end }}`}, }, Values{ @@ -363,36 +359,3 @@ Answer: `, }) } } - -func TestFuncs(t *testing.T) { - t.Run("contents", func(t *testing.T) { - cases := map[string]string{ - "": "A\n\nB\n\nC\n\nD\n\nE\n\nF", - "system": "A\n\nF", - "user": "B\n\nE", - "assistant": "C\n\nD", - } - - s := []*api.Message{ - {Role: "system", Content: "A"}, - {Role: "user", Content: "B"}, - {Role: "assistant", Content: "C"}, - {Role: "assistant", Content: "D"}, - {Role: "user", Content: "E"}, - {Role: "system", Content: "F"}, - } - - fn, ok := funcs["contents"].(func([]*api.Message, ...string) string) - if !ok { - t.Fatal("contents is not a function") - } - - for k, v := range cases { - t.Run(k, func(t *testing.T) { - if diff := cmp.Diff(fn(s, k), v); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - } - }) -} From 33627331a370755ff5033c0fcd71d1c9210c9d96 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Fri, 12 Jul 2024 12:29:23 -0700 Subject: [PATCH 03/25] app: also clean up tempdir runners on install (#5646) --- app/ollama.iss | 1 + 1 file changed, 1 insertion(+) diff --git a/app/ollama.iss b/app/ollama.iss index fef4a7b25..6bedb9ff7 100644 --- a/app/ollama.iss +++ b/app/ollama.iss @@ -128,6 +128,7 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history" ; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved [InstallDelete] +Type: filesandordirs; Name: "{%TEMP}\ollama*" Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama" [Messages] From 9ac0a7a50b8d7a0f0627b037c7632181bfbcca97 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 12 Jul 2024 15:41:31 -0700 Subject: [PATCH 04/25] remove template from tests --- cmd/interactive_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmd/interactive_test.go b/cmd/interactive_test.go index d9af01eb8..711f38604 100644 --- a/cmd/interactive_test.go +++ b/cmd/interactive_test.go @@ -59,7 +59,6 @@ func TestModelfileBuilder(t *testing.T) { opts := runOptions{ Model: "hork", System: "You are part horse and part shark, but all hork. Do horklike things", - Template: "This is a template.", Messages: []api.Message{ {Role: "user", Content: "Hey there hork!"}, {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, @@ -75,7 +74,6 @@ func TestModelfileBuilder(t *testing.T) { mf := buildModelfile(opts) expectedModelfile := `FROM {{.Model}} SYSTEM """{{.System}}""" -TEMPLATE """{{.Template}}""" PARAMETER penalize_newline false PARAMETER seed 42 PARAMETER stop [hi there] @@ -97,7 +95,6 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark.""" mf = buildModelfile(opts) expectedModelfile = `FROM {{.ParentModel}} SYSTEM """{{.System}}""" -TEMPLATE """{{.Template}}""" PARAMETER penalize_newline false PARAMETER seed 42 PARAMETER stop [hi there] From 23ebbaa46ead40c44c20b707b0e53d954ea51dc5 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 12 Jul 2024 15:47:17 -0700 Subject: [PATCH 05/25] Revert "remove template from tests" This reverts commit 9ac0a7a50b8d7a0f0627b037c7632181bfbcca97. --- cmd/interactive_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/interactive_test.go b/cmd/interactive_test.go index 711f38604..d9af01eb8 100644 --- a/cmd/interactive_test.go +++ b/cmd/interactive_test.go @@ -59,6 +59,7 @@ func TestModelfileBuilder(t *testing.T) { opts := runOptions{ Model: "hork", System: "You are part horse and part shark, but all hork. Do horklike things", + Template: "This is a template.", Messages: []api.Message{ {Role: "user", Content: "Hey there hork!"}, {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, @@ -74,6 +75,7 @@ func TestModelfileBuilder(t *testing.T) { mf := buildModelfile(opts) expectedModelfile := `FROM {{.Model}} SYSTEM """{{.System}}""" +TEMPLATE """{{.Template}}""" PARAMETER penalize_newline false PARAMETER seed 42 PARAMETER stop [hi there] @@ -95,6 +97,7 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark.""" mf = buildModelfile(opts) expectedModelfile = `FROM {{.ParentModel}} SYSTEM """{{.System}}""" +TEMPLATE """{{.Template}}""" PARAMETER penalize_newline false PARAMETER seed 42 PARAMETER stop [hi there] From 22c5451fc28b20dd83a389c49d9caf6a1e50a9e3 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 12 Jul 2024 21:04:44 -0700 Subject: [PATCH 06/25] fix system prompt (#5662) * fix system prompt * execute template when hitting previous roles * fix tests --------- Co-authored-by: jmorganca --- server/prompt.go | 23 +++++++---------------- server/prompt_test.go | 18 ++++++++++++++++++ template/template.go | 40 ++++++++++++++++++++++++++-------------- 3 files changed, 51 insertions(+), 30 deletions(-) diff --git a/server/prompt.go b/server/prompt.go index 51d691a9f..abc5e61e1 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "log/slog" - "slices" "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" @@ -17,26 +16,18 @@ type tokenizeFunc func(context.Context, string) ([]int, error) // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // latest message and 2) system messages func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) { - // pull out any system messages which should always be included in the prompt var system []api.Message - msgs = slices.DeleteFunc(msgs, func(m api.Message) bool { - if m.Role == "system" { - system = append(system, m) - return true - } - - return false - }) - - if len(system) == 0 && m.System != "" { - // add model system prompt since it wasn't provided - system = append(system, api.Message{Role: "system", Content: m.System}) - } - // always include the last message n := len(msgs) - 1 // in reverse, find all messages that fit into context window for i := n - 1; i >= 0; i-- { + system = make([]api.Message, 0) + for j := range i { + if msgs[j].Role == "system" { + system = append(system, msgs[j]) + } + } + var b bytes.Buffer if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil { return "", nil, err diff --git a/server/prompt_test.go b/server/prompt_test.go index 1435b143a..d8caf3ed2 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" "github.com/ollama/ollama/template" ) @@ -164,6 +165,19 @@ func TestChatPrompt(t *testing.T) { prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ", }, }, + { + name: "out of order system", + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "system", Content: "You are the Test Who Lived."}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, + }, + expect: expect{ + prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ", + }, + }, } tmpl, err := template.Parse(` @@ -187,6 +201,10 @@ func TestChatPrompt(t *testing.T) { t.Errorf("expected %q, got %q", tt.prompt, prompt) } + if diff := cmp.Diff(prompt, tt.prompt); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + if len(images) != len(tt.images) { t.Fatalf("expected %d images, got %d", len(tt.images), len(images)) } diff --git a/template/template.go b/template/template.go index 9b3516665..90014ec1a 100644 --- a/template/template.go +++ b/template/template.go @@ -149,27 +149,19 @@ type Values struct { } func (t *Template) Execute(w io.Writer, v Values) error { - system, collated := collate(v.Messages) + system, messages := collate(v.Messages) if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { return t.Template.Execute(w, map[string]any{ "System": system, - "Messages": collated, + "Messages": messages, }) } + system = "" var b bytes.Buffer var prompt, response string - for i, m := range collated { - switch m.Role { - case "system": - system = m.Content - case "user": - prompt = m.Content - case "assistant": - response = m.Content - } - - if i != len(collated)-1 && prompt != "" && response != "" { + for _, m := range messages { + execute := func () error { if err := t.Template.Execute(&b, map[string]any{ "System": system, "Prompt": prompt, @@ -181,6 +173,26 @@ func (t *Template) Execute(w io.Writer, v Values) error { system = "" prompt = "" response = "" + return nil + } + + switch m.Role { + case "system": + if prompt != "" || response != "" { + if err := execute(); err != nil { + return err + } + } + system = m.Content + case "user": + if response != "" { + if err := execute(); err != nil { + return err + } + } + prompt = m.Content + case "assistant": + response = m.Content } } @@ -199,7 +211,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { tree := parse.Tree{Root: nodes.(*parse.ListNode)} if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{ - "System": "", + "System": system, "Prompt": prompt, }); err != nil { return err From 02fea420e5a0042d5e4cfbb5024a6d7e092dc789 Mon Sep 17 00:00:00 2001 From: Jarek Date: Sat, 13 Jul 2024 17:33:46 +0200 Subject: [PATCH 07/25] Add Kerlig AI, an app for macOS (#5675) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 62f5cd65c..eb5e85329 100644 --- a/README.md +++ b/README.md @@ -293,6 +293,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama) +- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS) ### Terminal From ef98803d63a4e4c56853688343f011256ced130d Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 13 Jul 2024 09:20:05 -0700 Subject: [PATCH 08/25] llm: looser checks for minimum memory (#5677) --- llm/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llm/server.go b/llm/server.go index 8f37aa23a..ffed9fc02 100644 --- a/llm/server.go +++ b/llm/server.go @@ -127,7 +127,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr // On linux, over-allocating CPU memory will almost always result in an error if runtime.GOOS == "linux" { systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize - available := min(systemTotalMemory, systemFreeMemory+systemSwapFreeMemory) + available := systemFreeMemory + systemSwapFreeMemory if systemMemoryRequired > available { slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory)) return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available)) From 1ed0aa8feab58a5cbdf2d79fdb718e3a5cc03525 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 13 Jul 2024 09:25:31 -0700 Subject: [PATCH 09/25] server: fix `context`, `load_duration` and `total_duration` fields (#5676) * server: fix `contet`, `load_duration` and `total_duration` fields * Update server/routes.go --- server/routes.go | 56 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/server/routes.go b/server/routes.go index 4059c7c52..5b6d09788 100644 --- a/server/routes.go +++ b/server/routes.go @@ -102,6 +102,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil } func (s *Server) GenerateHandler(c *gin.Context) { + checkpointStart := time.Now() var req api.GenerateRequest if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) @@ -129,6 +130,8 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + checkpointLoaded := time.Now() + if req.Prompt == "" { c.JSON(http.StatusOK, api.GenerateResponse{ Model: req.Model, @@ -191,26 +194,48 @@ func (s *Server) GenerateHandler(c *gin.Context) { ch := make(chan any) go func() { + // TODO (jmorganca): avoid building the response twice both here and below + var sb strings.Builder defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, Format: req.Format, Options: opts, - }, func(r llm.CompletionResponse) { - ch <- api.GenerateResponse{ + }, func(cr llm.CompletionResponse) { + res := api.GenerateResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), - Response: r.Content, - Done: r.Done, - DoneReason: r.DoneReason, + Response: cr.Content, + Done: cr.Done, + DoneReason: cr.DoneReason, Metrics: api.Metrics{ - PromptEvalCount: r.PromptEvalCount, - PromptEvalDuration: r.PromptEvalDuration, - EvalCount: r.EvalCount, - EvalDuration: r.EvalDuration, + PromptEvalCount: cr.PromptEvalCount, + PromptEvalDuration: cr.PromptEvalDuration, + EvalCount: cr.EvalCount, + EvalDuration: cr.EvalDuration, }, } + + if _, err := sb.WriteString(cr.Content); err != nil { + ch <- gin.H{"error": err.Error()} + } + + if cr.Done { + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + + if !req.Raw { + tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String()) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } + res.Context = append(req.Context, tokens...) + } + } + + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} } @@ -1122,6 +1147,8 @@ func (s *Server) ProcessHandler(c *gin.Context) { } func (s *Server) ChatHandler(c *gin.Context) { + checkpointStart := time.Now() + var req api.ChatRequest if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) @@ -1141,6 +1168,8 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + checkpointLoaded := time.Now() + if len(req.Messages) == 0 { c.JSON(http.StatusOK, api.ChatResponse{ Model: req.Model, @@ -1169,7 +1198,7 @@ func (s *Server) ChatHandler(c *gin.Context) { Format: req.Format, Options: opts, }, func(r llm.CompletionResponse) { - ch <- api.ChatResponse{ + res := api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), Message: api.Message{Role: "assistant", Content: r.Content}, @@ -1182,6 +1211,13 @@ func (s *Server) ChatHandler(c *gin.Context) { EvalDuration: r.EvalDuration, }, } + + if r.Done { + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + } + + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} } From f7ee0123008dbdb3fd5954438d12196951b58b78 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 13 Jul 2024 15:08:00 -0700 Subject: [PATCH 10/25] server: prepend system message in chat handler --- server/routes.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/routes.go b/server/routes.go index 5b6d09788..edaec6912 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1181,6 +1181,10 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + if req.Messages[0].Role != "system" { + req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...) + } + prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) From 057d31861e3514b60a7eedf694899067b72bd2fa Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Sat, 13 Jul 2024 20:56:24 -0700 Subject: [PATCH 11/25] remove template (#5655) --- api/types.go | 2 ++ cmd/cmd.go | 2 -- cmd/interactive.go | 52 +++++++++++------------------------------ cmd/interactive_test.go | 3 --- server/routes.go | 7 ------ 5 files changed, 16 insertions(+), 50 deletions(-) diff --git a/api/types.go b/api/types.go index 87844c67c..91c97c715 100644 --- a/api/types.go +++ b/api/types.go @@ -221,6 +221,8 @@ type DeleteRequest struct { type ShowRequest struct { Model string `json:"model"` System string `json:"system"` + + // Template is deprecated Template string `json:"template"` Verbose bool `json:"verbose"` diff --git a/cmd/cmd.go b/cmd/cmd.go index c898c7db6..2252a905e 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -843,7 +843,6 @@ type runOptions struct { WordWrap bool Format string System string - Template string Images []api.ImageData Options map[string]interface{} MultiModal bool @@ -1037,7 +1036,6 @@ func generate(cmd *cobra.Command, opts runOptions) error { Images: opts.Images, Format: opts.Format, System: opts.System, - Template: opts.Template, Options: opts.Options, KeepAlive: opts.KeepAlive, } diff --git a/cmd/interactive.go b/cmd/interactive.go index 9214f2db5..adbc3e9fb 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -27,7 +27,6 @@ const ( MultilineNone MultilineState = iota MultilinePrompt MultilineSystem - MultilineTemplate ) func loadModel(cmd *cobra.Command, opts *runOptions) error { @@ -94,7 +93,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter") fmt.Fprintln(os.Stderr, " /set system Set system message") - fmt.Fprintln(os.Stderr, " /set template Set prompt template") fmt.Fprintln(os.Stderr, " /set history Enable history") fmt.Fprintln(os.Stderr, " /set nohistory Disable history") fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") @@ -204,10 +202,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System}) fmt.Println("Set system message.") sb.Reset() - case MultilineTemplate: - opts.Template = sb.String() - fmt.Println("Set prompt template.") - sb.Reset() } multiline = MultilineNone @@ -326,17 +320,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", ")) opts.Options[args[2]] = fp[args[2]] - case "system", "template": + case "system": if len(args) < 3 { usageSet() continue } - if args[1] == "system" { - multiline = MultilineSystem - } else if args[1] == "template" { - multiline = MultilineTemplate - } + multiline = MultilineSystem line := strings.Join(args[2:], " ") line, ok := strings.CutPrefix(line, `"""`) @@ -356,23 +346,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { continue } - if args[1] == "system" { - opts.System = sb.String() // for display in modelfile - newMessage := api.Message{Role: "system", Content: sb.String()} - // Check if the slice is not empty and the last message is from 'system' - if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" { - // Replace the last message - opts.Messages[len(opts.Messages)-1] = newMessage - } else { - opts.Messages = append(opts.Messages, newMessage) - } - fmt.Println("Set system message.") - sb.Reset() - } else if args[1] == "template" { - opts.Template = sb.String() - fmt.Println("Set prompt template.") - sb.Reset() + opts.System = sb.String() // for display in modelfile + newMessage := api.Message{Role: "system", Content: sb.String()} + // Check if the slice is not empty and the last message is from 'system' + if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" { + // Replace the last message + opts.Messages[len(opts.Messages)-1] = newMessage + } else { + opts.Messages = append(opts.Messages, newMessage) } + fmt.Println("Set system message.") + sb.Reset() sb.Reset() continue @@ -393,7 +377,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { req := &api.ShowRequest{ Name: opts.Model, System: opts.System, - Template: opts.Template, Options: opts.Options, } resp, err := client.Show(cmd.Context(), req) @@ -437,12 +420,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Println("No system message was specified for this model.") } case "template": - switch { - case opts.Template != "": - fmt.Println(opts.Template + "\n") - case resp.Template != "": + if resp.Template != "" { fmt.Println(resp.Template) - default: + } else { fmt.Println("No prompt template was specified for this model.") } default: @@ -536,10 +516,6 @@ func buildModelfile(opts runOptions) string { fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System) } - if opts.Template != "" { - fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template) - } - keys := make([]string, 0) for k := range opts.Options { keys = append(keys, k) diff --git a/cmd/interactive_test.go b/cmd/interactive_test.go index d9af01eb8..711f38604 100644 --- a/cmd/interactive_test.go +++ b/cmd/interactive_test.go @@ -59,7 +59,6 @@ func TestModelfileBuilder(t *testing.T) { opts := runOptions{ Model: "hork", System: "You are part horse and part shark, but all hork. Do horklike things", - Template: "This is a template.", Messages: []api.Message{ {Role: "user", Content: "Hey there hork!"}, {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, @@ -75,7 +74,6 @@ func TestModelfileBuilder(t *testing.T) { mf := buildModelfile(opts) expectedModelfile := `FROM {{.Model}} SYSTEM """{{.System}}""" -TEMPLATE """{{.Template}}""" PARAMETER penalize_newline false PARAMETER seed 42 PARAMETER stop [hi there] @@ -97,7 +95,6 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark.""" mf = buildModelfile(opts) expectedModelfile = `FROM {{.ParentModel}} SYSTEM """{{.System}}""" -TEMPLATE """{{.Template}}""" PARAMETER penalize_newline false PARAMETER seed 42 PARAMETER stop [hi there] diff --git a/server/routes.go b/server/routes.go index edaec6912..0a00d9e23 100644 --- a/server/routes.go +++ b/server/routes.go @@ -574,13 +574,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { m.System = req.System } - if req.Template != "" { - m.Template, err = template.Parse(req.Template) - if err != nil { - return nil, err - } - } - msgs := make([]api.Message, len(m.Messages)) for i, msg := range m.Messages { msgs[i] = api.Message{Role: msg.Role, Content: msg.Content} From e9f7f3602961d2b0beaff27144ec89301c2173ca Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Sat, 13 Jul 2024 22:07:45 -0700 Subject: [PATCH 12/25] Support image input for OpenAI chat compatibility (#5208) * OpenAI v1 models * Refactor Writers * Add Test Co-Authored-By: Attila Kerekes * Credit Co-Author Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com> * Empty List Testing * Use Namespace for Ownedby * Update Test * Add back envconfig * v1/models docs * Use ModelName Parser * Test Names * Remove Docs * Clean Up * Test name Co-authored-by: Jeffrey Morgan * Add Middleware for Chat and List * Testing Cleanup * Test with Fatal * Add functionality to chat test * Support image input for OpenAI chat * Decoding * Fix message processing logic * openai vision test * type errors * clean up * redundant check * merge conflicts * merge conflicts * merge conflicts * flattening and smaller image * add test * support python and js SDKs and mandate prefixing * clean up --------- Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com> Co-authored-by: Jeffrey Morgan --- openai/openai.go | 76 +++++++++++++++++++++++++++++++++++++++---- openai/openai_test.go | 49 ++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 6 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 1707da14b..b289d73e8 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -3,11 +3,13 @@ package openai import ( "bytes" + "encoding/base64" "encoding/json" "fmt" "io" "math/rand" "net/http" + "strings" "time" "github.com/gin-gonic/gin" @@ -28,7 +30,7 @@ type ErrorResponse struct { type Message struct { Role string `json:"role"` - Content string `json:"content"` + Content any `json:"content"` } type Choice struct { @@ -269,10 +271,66 @@ func toModel(r api.ShowResponse, m string) Model { } } -func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { +func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { var messages []api.Message for _, msg := range r.Messages { - messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content}) + switch content := msg.Content.(type) { + case string: + messages = append(messages, api.Message{Role: msg.Role, Content: content}) + case []any: + message := api.Message{Role: msg.Role} + for _, c := range content { + data, ok := c.(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + switch data["type"] { + case "text": + text, ok := data["text"].(string) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + message.Content = text + case "image_url": + var url string + if urlMap, ok := data["image_url"].(map[string]any); ok { + if url, ok = urlMap["url"].(string); !ok { + return nil, fmt.Errorf("invalid message format") + } + } else { + if url, ok = data["image_url"].(string); !ok { + return nil, fmt.Errorf("invalid message format") + } + } + + types := []string{"jpeg", "jpg", "png"} + valid := false + for _, t := range types { + prefix := "data:image/" + t + ";base64," + if strings.HasPrefix(url, prefix) { + url = strings.TrimPrefix(url, prefix) + valid = true + break + } + } + + if !valid { + return nil, fmt.Errorf("invalid image input") + } + + img, err := base64.StdEncoding.DecodeString(url) + if err != nil { + return nil, fmt.Errorf("invalid message format") + } + message.Images = append(message.Images, img) + default: + return nil, fmt.Errorf("invalid message format") + } + } + messages = append(messages, message) + default: + return nil, fmt.Errorf("invalid message content type: %T", content) + } } options := make(map[string]interface{}) @@ -323,13 +381,13 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { format = "json" } - return api.ChatRequest{ + return &api.ChatRequest{ Model: r.Model, Messages: messages, Format: format, Options: options, Stream: &r.Stream, - } + }, nil } func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { @@ -656,7 +714,13 @@ func ChatMiddleware() gin.HandlerFunc { } var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil { + + chatReq, err := fromChatRequest(req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + } + + if err := json.NewEncoder(&b).Encode(chatReq); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) return } diff --git a/openai/openai_test.go b/openai/openai_test.go index 5f1ae52e9..99f8baaf3 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -2,6 +2,7 @@ package openai import ( "bytes" + "encoding/base64" "encoding/json" "io" "net/http" @@ -15,6 +16,10 @@ import ( "github.com/stretchr/testify/assert" ) +const prefix = `data:image/jpeg;base64,` +const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` +const imageURL = prefix + image + func TestMiddlewareRequests(t *testing.T) { type testCase struct { Name string @@ -112,6 +117,50 @@ func TestMiddlewareRequests(t *testing.T) { } }, }, + { + Name: "chat handler with image content", + Method: http.MethodPost, + Path: "/api/chat", + Handler: ChatMiddleware, + Setup: func(t *testing.T, req *http.Request) { + body := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + { + Role: "user", Content: []map[string]any{ + {"type": "text", "text": "Hello"}, + {"type": "image_url", "image_url": map[string]string{"url": imageURL}}, + }, + }, + }, + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, req *http.Request) { + var chatReq api.ChatRequest + if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil { + t.Fatal(err) + } + + if chatReq.Messages[0].Role != "user" { + t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role) + } + + if chatReq.Messages[0].Content != "Hello" { + t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content) + } + + img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):]) + + if !bytes.Equal(chatReq.Messages[0].Images[0], img) { + t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0]) + } + }, + }, } gin.SetMode(gin.TestMode) From b9f5e16c8025f115abde34ff047023f4d6e34af5 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Mon, 15 Jul 2024 12:14:24 -0700 Subject: [PATCH 13/25] Introduce `/api/embed` endpoint supporting batch embedding (#5127) * Initial Batch Embedding * Revert "Initial Batch Embedding" This reverts commit c22d54895a280b54c727279d85a5fc94defb5a29. * Initial Draft * mock up notes * api/embed draft * add server function * check normalization * clean up * normalization * playing around with truncate stuff * Truncation * Truncation * move normalization to go * Integration Test Template * Truncation Integration Tests * Clean up * use float32 * move normalize * move normalize test * refactoring * integration float32 * input handling and handler testing * Refactoring of legacy and new * clear comments * merge conflicts * touches * embedding type 64 * merge conflicts * fix hanging on single string * refactoring * test values * set context length * clean up * testing clean up * testing clean up * remove function closure * Revert "remove function closure" This reverts commit 55d48c6ed17abe42e7a122e69d603ef0c1506787. * remove function closure * remove redundant error check * clean up * more clean up * clean up --- api/client.go | 11 ++- api/types.go | 24 ++++++ integration/embed_test.go | 152 ++++++++++++++++++++++++++++++++++++++ llm/ext_server/server.cpp | 37 ++++++---- llm/server.go | 16 ++-- server/routes.go | 131 +++++++++++++++++++++++++++++++- server/routes_test.go | 103 ++++++++++++++++++++++++++ server/sched_test.go | 8 +- 8 files changed, 452 insertions(+), 30 deletions(-) create mode 100644 integration/embed_test.go diff --git a/api/client.go b/api/client.go index fccbc9ad7..c59fbc423 100644 --- a/api/client.go +++ b/api/client.go @@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error { return nil } -// Embeddings generates embeddings from a model. +// Embed generates embeddings from a model. +func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + var resp EmbedResponse + if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// Embeddings generates an embedding from a model. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) { var resp EmbeddingResponse if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil { diff --git a/api/types.go b/api/types.go index 91c97c715..bf5529283 100644 --- a/api/types.go +++ b/api/types.go @@ -173,6 +173,30 @@ type Runner struct { NumThread int `json:"num_thread,omitempty"` } +// EmbedRequest is the request passed to [Client.Embed]. +type EmbedRequest struct { + // Model is the model name. + Model string `json:"model"` + + // Input is the input to embed. + Input any `json:"input"` + + // KeepAlive controls how long the model will stay loaded in memory following + // this request. + KeepAlive *Duration `json:"keep_alive,omitempty"` + + Truncate *bool `json:"truncate,omitempty"` + + // Options lists model-specific options. + Options map[string]interface{} `json:"options"` +} + +// EmbedResponse is the response from [Client.Embed]. +type EmbedResponse struct { + Model string `json:"model"` + Embeddings [][]float32 `json:"embeddings,omitempty"` +} + // EmbeddingRequest is the request passed to [Client.Embeddings]. type EmbeddingRequest struct { // Model is the model name. diff --git a/integration/embed_test.go b/integration/embed_test.go new file mode 100644 index 000000000..aeafa57b6 --- /dev/null +++ b/integration/embed_test.go @@ -0,0 +1,152 @@ +//go:build integration + +package integration + +import ( + "context" + "testing" + "time" + + "github.com/ollama/ollama/api" +) + +func TestAllMiniLMEmbed(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + } + + res, err := embedTestHelper(ctx, t, req) + + if err != nil { + t.Fatalf("error: %v", err) + } + + if len(res.Embeddings) != 1 { + t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings)) + } + + if len(res.Embeddings[0]) != 384 { + t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) + } + + if res.Embeddings[0][0] != 0.010071031 { + t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0]) + } +} + +func TestAllMiniLMBatchEmbed(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{"why is the sky blue?", "why is the grass green?"}, + } + + res, err := embedTestHelper(ctx, t, req) + + if err != nil { + t.Fatalf("error: %v", err) + } + + if len(res.Embeddings) != 2 { + t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings)) + } + + if len(res.Embeddings[0]) != 384 { + t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) + } + + if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 { + t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0]) + } +} + +func TestAllMiniLmEmbedTruncate(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + truncTrue, truncFalse := true, false + + type testReq struct { + Name string + Request api.EmbedRequest + } + + reqs := []testReq{ + { + Name: "Target Truncation", + Request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why", + }, + }, + { + Name: "Default Truncate", + Request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Options: map[string]any{"num_ctx": 1}, + }, + }, + { + Name: "Explicit Truncate", + Request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 1}, + }, + }, + } + + res := make(map[string]*api.EmbedResponse) + + for _, req := range reqs { + response, err := embedTestHelper(ctx, t, req.Request) + if err != nil { + t.Fatalf("error: %v", err) + } + res[req.Name] = response + } + + if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] { + t.Fatal("expected default request to truncate correctly") + } + + if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] { + t.Fatal("expected default request and truncate true request to be the same") + } + + // check that truncate set to false returns an error if context length is exceeded + _, err := embedTestHelper(ctx, t, api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 1}, + }) + + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatalf("failed to pull model %s: %v", req.Model, err) + } + + response, err := client.Embed(ctx, &req) + + if err != nil { + return nil, err + } + + return response, nil +} diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index 0ef3956ec..e8a076c43 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -3188,26 +3188,33 @@ int main(int argc, char **argv) { prompt = ""; } - json image_data; - if (body.count("image_data") != 0) { - image_data = body["image_data"]; - } - else - { - image_data = ""; + if (prompt.size() == 1) { + prompt = prompt[0]; } // create and queue the task - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1); + json responses; + { + const int id_task = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(id_task); + llama.request_completion(id_task, {{"prompt", prompt}}, true, -1); - // get the result - task_result result = llama.queue_results.recv(task_id); - llama.queue_results.remove_waiting_task_id(task_id); + // get the result + task_result result = llama.queue_results.recv(id_task); + llama.queue_results.remove_waiting_task_id(id_task); + if (result.error) { + return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); + } - // send the result - return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); + responses = result.result_json.value("results", std::vector{result.result_json}); + json embeddings = json::array(); + for (auto & elem : responses) { + embeddings.push_back(elem.at("embedding")); + } + // send the result + json embedding_res = json{{"embedding", embeddings}}; + return res.set_content(embedding_res.dump(), "application/json; charset=utf-8"); + } }); // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? diff --git a/llm/server.go b/llm/server.go index ffed9fc02..36c0e0b55 100644 --- a/llm/server.go +++ b/llm/server.go @@ -33,7 +33,7 @@ type LlamaServer interface { Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error - Embedding(ctx context.Context, prompt string) ([]float64, error) + Embed(ctx context.Context, input []string) ([][]float32, error) Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error @@ -867,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return nil } -type EmbeddingRequest struct { - Content string `json:"content"` +type EmbedRequest struct { + Content []string `json:"content"` } -type EmbeddingResponse struct { - Embedding []float64 `json:"embedding"` +type EmbedResponse struct { + Embedding [][]float32 `json:"embedding"` } -func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) { +func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) { if err := s.sem.Acquire(ctx, 1); err != nil { slog.Error("Failed to acquire semaphore", "error", err) return nil, err @@ -890,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) } - data, err := json.Marshal(TokenizeRequest{Content: prompt}) + data, err := json.Marshal(EmbedRequest{Content: input}) if err != nil { return nil, fmt.Errorf("error marshaling embed data: %w", err) } @@ -917,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er return nil, fmt.Errorf("%s", body) } - var embedding EmbeddingResponse + var embedding EmbedResponse if err := json.Unmarshal(body, &embedding); err != nil { return nil, fmt.Errorf("unmarshal tokenize response: %w", err) } diff --git a/server/routes.go b/server/routes.go index 0a00d9e23..c5c3a19ca 100644 --- a/server/routes.go +++ b/server/routes.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "log/slog" + "math" "net" "net/http" "net/netip" @@ -271,6 +272,121 @@ func (s *Server) GenerateHandler(c *gin.Context) { streamResponse(c, ch) } +func (s *Server) EmbedHandler(c *gin.Context) { + var req api.EmbedRequest + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + truncate := true + + if req.Truncate != nil && !*req.Truncate { + truncate = false + } + + var input []string + + switch i := req.Input.(type) { + case string: + if len(i) > 0 { + input = append(input, i) + } + case []any: + for _, v := range i { + if _, ok := v.(string); !ok { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) + return + } + input = append(input, v.(string)) + } + default: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) + return + } + + if len(input) == 0 { + c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) + return + } + + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) + if err != nil { + handleScheduleError(c, req.Model, err) + return + } + + kvData, err := getKVData(m.ModelPath, false) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + for i, s := range input { + tokens, err := r.Tokenize(c.Request.Context(), s) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) + if len(tokens) > ctxLen { + if !truncate { + c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"}) + return + } + + tokens = tokens[:ctxLen] + s, err = r.Detokenize(c.Request.Context(), tokens) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + input[i] = s + } + embeddings, err := r.Embed(c.Request.Context(), input) + + if err != nil { + slog.Error("embedding generation failed", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) + return + } + + for i, e := range embeddings { + embeddings[i] = normalize(e) + } + + resp := api.EmbedResponse{ + Model: req.Model, + Embeddings: embeddings, + } + c.JSON(http.StatusOK, resp) +} + +func normalize(vec []float32) []float32 { + var sum float32 + for _, v := range vec { + sum += v * v + } + + norm := float32(0.0) + if sum > 0 { + norm = float32(1.0 / math.Sqrt(float64(sum))) + } + + for i := range vec { + vec[i] *= norm + } + return vec +} + func (s *Server) EmbeddingsHandler(c *gin.Context) { var req api.EmbeddingRequest if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { @@ -293,14 +409,24 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embedding, err := r.Embedding(c.Request.Context(), req.Prompt) + embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}) + if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) return } - c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding}) + embedding := make([]float64, len(embeddings[0])) + + for i, v := range embeddings[0] { + embedding[i] = float64(v) + } + + resp := api.EmbeddingResponse{ + Embedding: embedding, + } + c.JSON(http.StatusOK, resp) } func (s *Server) PullModelHandler(c *gin.Context) { @@ -919,6 +1045,7 @@ func (s *Server) GenerateRoutes() http.Handler { r.POST("/api/pull", s.PullModelHandler) r.POST("/api/generate", s.GenerateHandler) r.POST("/api/chat", s.ChatHandler) + r.POST("/api/embed", s.EmbedHandler) r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/create", s.CreateModelHandler) r.POST("/api/push", s.PushModelHandler) diff --git a/server/routes_test.go b/server/routes_test.go index 50eaf7e97..70622e9b0 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "math" "net/http" "net/http/httptest" "os" @@ -272,6 +273,73 @@ func Test_Routes(t *testing.T) { assert.Equal(t, "library", retrieveResp.OwnedBy) }, }, + { + Name: "Embed Handler Empty Input", + Method: http.MethodPost, + Path: "/api/embed", + Setup: func(t *testing.T, req *http.Request) { + embedReq := api.EmbedRequest{ + Model: "t-bone", + Input: "", + } + jsonData, err := json.Marshal(embedReq) + require.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + if contentType != "application/json; charset=utf-8" { + t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + var embedResp api.EmbedResponse + err = json.Unmarshal(body, &embedResp) + if err != nil { + t.Fatal(err) + } + + if embedResp.Model != "t-bone" { + t.Fatalf("expected model t-bone, got %s", embedResp.Model) + } + + if embedResp.Embeddings != nil { + t.Fatalf("expected embeddings to be nil, got %v", embedResp.Embeddings) + } + }, + }, + { + Name: "Embed Handler Invalid Input", + Method: http.MethodPost, + Path: "/api/embed", + Setup: func(t *testing.T, req *http.Request) { + embedReq := api.EmbedRequest{ + Model: "t-bone", + Input: 2, + } + jsonData, err := json.Marshal(embedReq) + require.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + if contentType != "application/json; charset=utf-8" { + t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType) + } + _, err := io.ReadAll(resp.Body) + + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status code 400, got %d", resp.StatusCode) + } + }, + }, } t.Setenv("OLLAMA_MODELS", t.TempDir()) @@ -420,3 +488,38 @@ func TestShow(t *testing.T) { t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"]) } } + +func TestNormalize(t *testing.T) { + type testCase struct { + input []float32 + } + + testCases := []testCase{ + {input: []float32{1}}, + {input: []float32{0, 1, 2, 3}}, + {input: []float32{0.1, 0.2, 0.3}}, + {input: []float32{-0.1, 0.2, 0.3, -0.4}}, + {input: []float32{0, 0, 0}}, + } + + isNormalized := func(vec []float32) (res bool) { + sum := 0.0 + for _, v := range vec { + sum += float64(v * v) + } + if math.Abs(sum-1) > 1e-6 { + return sum == 0 + } else { + return true + } + } + + for _, tc := range testCases { + t.Run("", func(t *testing.T) { + normalized := normalize(tc.input) + if !isNormalized(normalized) { + t.Errorf("Vector %v is not normalized", tc.input) + } + }) + } +} diff --git a/server/sched_test.go b/server/sched_test.go index 3fbd188a7..4b000331e 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -642,8 +642,8 @@ type mockLlm struct { pingResp error waitResp error completionResp error - embeddingResp []float64 - embeddingRespErr error + embedResp [][]float32 + embedRespErr error tokenizeResp []int tokenizeRespErr error detokenizeResp string @@ -660,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { return s.completionResp } -func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) { - return s.embeddingResp, s.embeddingRespErr +func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) { + return s.embedResp, s.embedRespErr } func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { return s.tokenizeResp, s.tokenizeRespErr From 9e35d9bbee4c96ca064bcb7eadc5b2eb3a200ce7 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 15 Jul 2024 13:55:57 -0700 Subject: [PATCH 14/25] server: lowercase roles for compatibility with clients (#5695) --- api/types.go | 16 ++++++++++++++-- api/types_test.go | 23 +++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/api/types.go b/api/types.go index bf5529283..3b607cecb 100644 --- a/api/types.go +++ b/api/types.go @@ -110,6 +110,18 @@ type Message struct { Images []ImageData `json:"images,omitempty"` } +func (m *Message) UnmarshalJSON(b []byte) error { + type Alias Message + var a Alias + if err := json.Unmarshal(b, &a); err != nil { + return err + } + + *m = Message(a) + m.Role = strings.ToLower(m.Role) + return nil +} + // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { @@ -243,8 +255,8 @@ type DeleteRequest struct { // ShowRequest is the request passed to [Client.Show]. type ShowRequest struct { - Model string `json:"model"` - System string `json:"system"` + Model string `json:"model"` + System string `json:"system"` // Template is deprecated Template string `json:"template"` diff --git a/api/types_test.go b/api/types_test.go index c60ed90e0..4699c1503 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -208,3 +208,26 @@ func TestUseMmapFormatParams(t *testing.T) { }) } } + +func TestMessage_UnmarshalJSON(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {`{"role": "USER", "content": "Hello!"}`, "user"}, + {`{"role": "System", "content": "Initialization complete."}`, "system"}, + {`{"role": "assistant", "content": "How can I help you?"}`, "assistant"}, + {`{"role": "TOOl", "content": "Access granted."}`, "tool"}, + } + + for _, test := range tests { + var msg Message + if err := json.Unmarshal([]byte(test.input), &msg); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if msg.Role != test.expected { + t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected) + } + } +} From 224337b32f26e813ab0cf4a6544c1316715a5d41 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 15 Jul 2024 15:10:22 -0700 Subject: [PATCH 15/25] Bump linux ROCm to 6.1.2 --- .github/workflows/test.yaml | 2 +- Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 977d8da14..90fef6e59 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -126,7 +126,7 @@ jobs: strategy: matrix: rocm-version: - - '6.1.1' + - '6.1.2' runs-on: linux container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }} steps: diff --git a/Dockerfile b/Dockerfile index b2c5c4a2f..ca3934964 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1 ARG CMAKE_VERSION=3.22.1 # this CUDA_VERSION corresponds with the one specified in docs/gpu.md ARG CUDA_VERSION=11.3.1 -ARG ROCM_VERSION=6.1.1 +ARG ROCM_VERSION=6.1.2 # Copy the minimal context we need to run the generate scripts FROM scratch AS llm-code From d02bbebb11c2e9c391ee3af30ba3437e67d1b7a8 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 20 Jun 2024 13:45:47 -0700 Subject: [PATCH 16/25] tools --- api/types.go | 39 ++++++++++++- server/images.go | 11 +++- server/model.go | 105 ++++++++++++++++++++++++++++++++++ server/prompt.go | 6 +- server/prompt_test.go | 2 +- server/routes.go | 24 ++++++-- template/template.go | 129 +++++++++++++++++++++++++++++------------- 7 files changed, 263 insertions(+), 53 deletions(-) diff --git a/api/types.go b/api/types.go index 3b607cecb..97af4aed0 100644 --- a/api/types.go +++ b/api/types.go @@ -97,6 +97,9 @@ type ChatRequest struct { // followin the request. KeepAlive *Duration `json:"keep_alive,omitempty"` + // Tools is an optional list of tools the model has access to. + Tools []Tool `json:"tools,omitempty"` + // Options lists model-specific options. Options map[string]interface{} `json:"options"` } @@ -105,9 +108,36 @@ type ChatRequest struct { // role ("system", "user", or "assistant"), the content and an optional list // of images. type Message struct { - Role string `json:"role"` - Content string `json:"content"` - Images []ImageData `json:"images,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Images []ImageData `json:"images,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + } `json:"function"` +} + +type Tool struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters struct { + Type string `json:"type"` + Required []string `json:"required"` + Properties map[string]struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` + } `json:"properties"` + } `json:"parameters"` + } `json:"function"` } func (m *Message) UnmarshalJSON(b []byte) error { @@ -374,6 +404,9 @@ type GenerateResponse struct { // Response is the textual response itself. Response string `json:"response"` + // ToolCalls is the list of tools the model wants to call + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + // Done specifies if the response is complete. Done bool `json:"done"` diff --git a/server/images.go b/server/images.go index 688d5dcae..1b87888ed 100644 --- a/server/images.go +++ b/server/images.go @@ -38,7 +38,10 @@ var errCapabilityCompletion = errors.New("completion") type Capability string -const CapabilityCompletion = Capability("completion") +const ( + CapabilityCompletion = Capability("completion") + CapabilityTools = Capability("tools") +) type registryOptions struct { Insecure bool @@ -88,6 +91,10 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { errs = append(errs, errCapabilityCompletion) } + case CapabilityTools: + if !slices.Contains(m.Template.Vars(), "tools") { + errs = append(errs, errors.New("tools")) + } default: slog.Error("unknown capability", "capability", cap) return fmt.Errorf("unknown capability: %s", cap) @@ -95,7 +102,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { } if err := errors.Join(errs...); err != nil { - return fmt.Errorf("missing capabilities: %w", errors.Join(errs...)) + return fmt.Errorf("does not support %w", errors.Join(errs...)) } return nil diff --git a/server/model.go b/server/model.go index a79f549a3..be318db9c 100644 --- a/server/model.go +++ b/server/model.go @@ -4,6 +4,7 @@ import ( "archive/zip" "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -11,7 +12,11 @@ import ( "net/http" "os" "path/filepath" + "slices" + "strings" + "text/template/parse" + "github.com/google/uuid" "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" "github.com/ollama/ollama/llm" @@ -289,3 +294,103 @@ func detectContentType(r io.Reader) (string, error) { return "unknown", nil } + +// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls. +// mxyng: this only really works if the input contains tool calls in some JSON format +func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { + // create a subtree from the node that ranges over .ToolCalls + tmpl := m.Template.Subtree(func(n parse.Node) bool { + if t, ok := n.(*parse.RangeNode); ok { + return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") + } + + return false + }) + + if tmpl == nil { + return nil, false + } + + var b bytes.Buffer + if err := tmpl.Execute(&b, map[string][]map[string]any{ + "ToolCalls": { + { + "Function": map[string]any{ + "Name": "@@name@@", + "Arguments": "@@arguments@@", + }, + }, + }, + }); err != nil { + return nil, false + } + + var kv map[string]string + // execute the subtree with placeholders to identify the keys + if err := json.Unmarshal(b.Bytes(), &kv); err != nil { + return nil, false + } + + // find the keys that correspond to the name and arguments fields + var name, arguments string + for k, v := range kv { + switch v { + case "@@name@@": + name = k + case "@@arguments@@": + arguments = k + } + } + + var sm []map[string]any + decoder := json.NewDecoder(strings.NewReader(s)) + for { + // incrementally decode the JSON into a list of JSON objects + // skipping over any invalid tokens + if err := decoder.Decode(&sm); err != nil { + if errors.Is(err, io.EOF) { + break + } + + if errors.As(err, new(*json.SyntaxError)) { + r := decoder.Buffered() + if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil { + break + } + + decoder = json.NewDecoder(r) + continue + } + + return nil, false + } + + // break as soon as a valid object is decoded + break + } + + var toolCalls []api.ToolCall + for _, kv := range sm { + call := api.ToolCall{ + ID: uuid.New().String(), + Type: "function", + } + + for k, v := range kv { + switch k { + case name: + call.Function.Name = v.(string) + case arguments: + call.Function.Arguments = v.(map[string]any) + } + } + + toolCalls = append(toolCalls, call) + } + + if len(toolCalls) > 0 { + return toolCalls, true + } + + return nil, false +} diff --git a/server/prompt.go b/server/prompt.go index abc5e61e1..be0d49692 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -15,7 +15,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error) // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // latest message and 2) system messages -func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) { +func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) { var system []api.Message // always include the last message n := len(msgs) - 1 @@ -29,7 +29,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. } var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil { + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil { return "", nil, err } @@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. // truncate any messages that do not fit into the context window var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil { + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil { return "", nil, err } diff --git a/server/prompt_test.go b/server/prompt_test.go index d8caf3ed2..9c4da0685 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) { t.Run(tt.name, func(t *testing.T) { model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} - prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs) + prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil) if err != nil { t.Fatal(err) } diff --git a/server/routes.go b/server/routes.go index c5c3a19ca..9712d8950 100644 --- a/server/routes.go +++ b/server/routes.go @@ -265,6 +265,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { } r.Response = sb.String() + if toolCalls, ok := m.parseToolCalls(sb.String()); ok { + r.ToolCalls = toolCalls + r.Response = "" + } + c.JSON(http.StatusOK, r) return } @@ -1279,6 +1284,10 @@ func (s *Server) ChatHandler(c *gin.Context) { } caps := []Capability{CapabilityCompletion} + if req.Tools != nil { + caps = append(caps, CapabilityTools) + } + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) @@ -1305,7 +1314,7 @@ func (s *Server) ChatHandler(c *gin.Context) { req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...) } - prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages) + prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -1348,13 +1357,13 @@ func (s *Server) ChatHandler(c *gin.Context) { }() if req.Stream != nil && !*req.Stream { - var r api.ChatResponse + var resp api.ChatResponse var sb strings.Builder for rr := range ch { switch t := rr.(type) { case api.ChatResponse: sb.WriteString(t.Message.Content) - r = t + resp = t case gin.H: msg, ok := t["error"].(string) if !ok { @@ -1369,8 +1378,13 @@ func (s *Server) ChatHandler(c *gin.Context) { } } - r.Message.Content = sb.String() - c.JSON(http.StatusOK, r) + resp.Message.Content = sb.String() + if toolCalls, ok := m.parseToolCalls(sb.String()); ok { + resp.Message.ToolCalls = toolCalls + resp.Message.Content = "" + } + + c.JSON(http.StatusOK, resp) return } diff --git a/template/template.go b/template/template.go index 90014ec1a..0e23cf1ce 100644 --- a/template/template.go +++ b/template/template.go @@ -13,6 +13,7 @@ import ( "sync" "text/template" "text/template/parse" + "time" "github.com/agnivade/levenshtein" "github.com/ollama/ollama/api" @@ -102,8 +103,18 @@ var response = parse.ActionNode{ }, } +var funcs = template.FuncMap{ + "json": func(v any) string { + b, _ := json.Marshal(v) + return string(b) + }, + "now": func() string { + return time.Now().Format("2006-01-02 15:04:05") + }, +} + func Parse(s string) (*Template, error) { - tmpl := template.New("").Option("missingkey=zero") + tmpl := template.New("").Option("missingkey=zero").Funcs(funcs) tmpl, err := tmpl.Parse(s) if err != nil { @@ -127,7 +138,7 @@ func (t *Template) Vars() []string { var vars []string for _, tt := range t.Templates() { for _, n := range tt.Root.Nodes { - vars = append(vars, parseNode(n)...) + vars = append(vars, Identifiers(n)...) } } @@ -143,17 +154,65 @@ func (t *Template) Vars() []string { type Values struct { Messages []api.Message + Tools []api.Tool // forceLegacy is a flag used to test compatibility with legacy templates forceLegacy bool } +func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template { + var walk func(parse.Node) parse.Node + walk = func(n parse.Node) parse.Node { + if fn(n) { + return n + } + + switch t := n.(type) { + case *parse.ListNode: + for _, c := range t.Nodes { + if n := walk(c); n != nil { + return n + } + } + case *parse.BranchNode: + for _, n := range []*parse.ListNode{t.List, t.ElseList} { + if n != nil { + if n := walk(n); n != nil { + return n + } + } + } + case *parse.IfNode: + return walk(&t.BranchNode) + case *parse.WithNode: + return walk(&t.BranchNode) + case *parse.RangeNode: + return walk(&t.BranchNode) + } + + return nil + } + + if n := walk(t.Tree.Root); n != nil { + return (&template.Template{ + Tree: &parse.Tree{ + Root: &parse.ListNode{ + Nodes: []parse.Node{n}, + }, + }, + }).Funcs(funcs) + } + + return nil +} + func (t *Template) Execute(w io.Writer, v Values) error { system, messages := collate(v.Messages) if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { return t.Template.Execute(w, map[string]any{ "System": system, "Messages": messages, + "Tools": v.Tools, }) } @@ -161,7 +220,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { var b bytes.Buffer var prompt, response string for _, m := range messages { - execute := func () error { + execute := func() error { if err := t.Template.Execute(&b, map[string]any{ "System": system, "Prompt": prompt, @@ -198,12 +257,8 @@ func (t *Template) Execute(w io.Writer, v Values) error { var cut bool nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool { - switch t := n.(type) { - case *parse.ActionNode: - case *parse.FieldNode: - if slices.Contains(t.Ident, "Response") { - cut = true - } + if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") { + cut = true } return cut @@ -255,50 +310,46 @@ func collate(msgs []api.Message) (string, []*api.Message) { return strings.Join(system, "\n\n"), collated } -func parseNode(n parse.Node) []string { +// Identifiers walks the node tree returning any identifiers it finds along the way +func Identifiers(n parse.Node) []string { switch n := n.(type) { + case *parse.ListNode: + var names []string + for _, n := range n.Nodes { + names = append(names, Identifiers(n)...) + } + + return names + case *parse.TemplateNode: + return Identifiers(n.Pipe) case *parse.ActionNode: - return parseNode(n.Pipe) + return Identifiers(n.Pipe) + case *parse.BranchNode: + names := Identifiers(n.Pipe) + for _, n := range []*parse.ListNode{n.List, n.ElseList} { + if n != nil { + names = append(names, Identifiers(n)...) + } + } + return names case *parse.IfNode: - names := parseNode(n.Pipe) - names = append(names, parseNode(n.List)...) - if n.ElseList != nil { - names = append(names, parseNode(n.ElseList)...) - } - return names + return Identifiers(&n.BranchNode) case *parse.RangeNode: - names := parseNode(n.Pipe) - names = append(names, parseNode(n.List)...) - if n.ElseList != nil { - names = append(names, parseNode(n.ElseList)...) - } - return names + return Identifiers(&n.BranchNode) case *parse.WithNode: - names := parseNode(n.Pipe) - names = append(names, parseNode(n.List)...) - if n.ElseList != nil { - names = append(names, parseNode(n.ElseList)...) - } - return names + return Identifiers(&n.BranchNode) case *parse.PipeNode: var names []string for _, c := range n.Cmds { for _, a := range c.Args { - names = append(names, parseNode(a)...) + names = append(names, Identifiers(a)...) } } - return names - case *parse.ListNode: - var names []string - for _, n := range n.Nodes { - names = append(names, parseNode(n)...) - } - return names case *parse.FieldNode: return n.Ident - case *parse.TemplateNode: - return parseNode(n.Pipe) + case *parse.VariableNode: + return n.Ident } return nil From ef5136a745138896d080bf5bcac13377f7672b77 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 15 Jul 2024 12:17:38 -0700 Subject: [PATCH 17/25] tools test --- server/model_test.go | 122 ++++++++++++++++++++ server/testdata/tools/command-r-plus.gotmpl | 67 +++++++++++ server/testdata/tools/command-r-plus.out | 39 +++++++ server/testdata/tools/firefunction.gotmpl | 31 +++++ server/testdata/tools/firefunction.out | 17 +++ server/testdata/tools/messages.json | 39 +++++++ server/testdata/tools/mistral.gotmpl | 15 +++ server/testdata/tools/mistral.out | 3 + server/testdata/tools/tools.json | 30 +++++ template/template.go | 4 - 10 files changed, 363 insertions(+), 4 deletions(-) create mode 100644 server/testdata/tools/command-r-plus.gotmpl create mode 100644 server/testdata/tools/command-r-plus.out create mode 100644 server/testdata/tools/firefunction.gotmpl create mode 100644 server/testdata/tools/firefunction.out create mode 100644 server/testdata/tools/messages.json create mode 100644 server/testdata/tools/mistral.gotmpl create mode 100644 server/testdata/tools/mistral.out create mode 100644 server/testdata/tools/tools.json diff --git a/server/model_test.go b/server/model_test.go index a383b7e72..025781928 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -3,7 +3,9 @@ package server import ( "archive/zip" "bytes" + "encoding/json" "errors" + "fmt" "io" "os" "path/filepath" @@ -11,7 +13,9 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" ) func createZipFile(t *testing.T, name string) *os.File { @@ -110,3 +114,121 @@ func TestExtractFromZipFile(t *testing.T) { }) } } + +type function struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` +} + +func readFile(t *testing.T, base, name string) *bytes.Buffer { + t.Helper() + + bts, err := os.ReadFile(filepath.Join(base, name)) + if err != nil { + t.Fatal(err) + } + + return bytes.NewBuffer(bts) +} + +func TestExecuteWithTools(t *testing.T) { + p := filepath.Join("testdata", "tools") + cases := []struct { + model string + output string + }{ + {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, + {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + +The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}, + {"command-r-plus", "Action: ```json" + ` +[ + { + "tool_name": "get_current_weather", + "parameters": { + "format": "fahrenheit", + "location": "San Francisco, CA" + } + }, + { + "tool_name": "get_current_weather", + "parameters": { + "format": "celsius", + "location": "Toronto, Canada" + } + } +] +` + "```"}, + {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, + } + + var tools []api.Tool + if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { + t.Fatal(err) + } + + var messages []api.Message + if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { + t.Fatal(err) + } + + calls := []api.ToolCall{ + { + Type: "function", + Function: function{ + Name: "get_current_weather", + Arguments: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + }, + { + Type: "function", + Function: function{ + Name: "get_current_weather", + Arguments: map[string]any{ + "format": "celsius", + "location": "Toronto, Canada", + }, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.model, func(t *testing.T) { + tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) + if err != nil { + t.Fatal(err) + } + + t.Run("template", func(t *testing.T) { + var actual bytes.Buffer + if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("parse", func(t *testing.T) { + m := &Model{Template: tmpl} + actual, ok := m.parseToolCalls(tt.output) + if !ok { + t.Fatal("failed to parse tool calls") + } + + for i := range actual { + // ID is randomly generated so clear it for comparison + actual[i].ID = "" + } + + if diff := cmp.Diff(actual, calls); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + }) + } +} diff --git a/server/testdata/tools/command-r-plus.gotmpl b/server/testdata/tools/command-r-plus.gotmpl new file mode 100644 index 000000000..088a4f0e5 --- /dev/null +++ b/server/testdata/tools/command-r-plus.gotmpl @@ -0,0 +1,67 @@ +{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|> +{{- if .Tools }}# Safety Preamble +The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral. + +# System Preamble +## Basic Rules +You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions. + +{{ if .System }}# User Preamble +{{ .System }} +{{- end }} + +## Available Tools +Here is a list of tools that you have available to you: +{{- range .Tools }} + +```python +def {{ .Function.Name }}( +{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]: + """{{ .Function.Description }} + +{{- if .Function.Parameters.Properties }} + + Args: +{{- range $name, $property := .Function.Parameters.Properties }} + {{ $name }} ({{ $property.Type }}): {{ $property.Description }} +{{- end }} +{{- end }} + """ + pass +``` +{{- end }} +{{- else if .System }}{{ .System }} +{{- end }}<|END_OF_TURN_TOKEN|> +{{- end }} +{{- range .Messages }} +{{- if eq .Role "system" }} +{{- continue }} +{{- end }}<|START_OF_TURN_TOKEN|> +{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }} +{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|> +{{- if .Content }}{{ .Content }} +{{- else if .ToolCalls }} +Action: ```json +[ +{{- range .ToolCalls }} + { + "tool_name": "{{ .Function.Name }}", + "parameters": {{ json .Function.Arguments }} + } +{{- end }} +]``` +{{ continue }} +{{ end }} +{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|> +{{ .Content }} +{{- end }}<|END_OF_TURN_TOKEN|> +{{- end }} +{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example: +```json +[ + { + "tool_name": title of the tool in the specification, + "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters + } +]``` +{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> \ No newline at end of file diff --git a/server/testdata/tools/command-r-plus.out b/server/testdata/tools/command-r-plus.out new file mode 100644 index 000000000..425af75ab --- /dev/null +++ b/server/testdata/tools/command-r-plus.out @@ -0,0 +1,39 @@ +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble +The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral. + +# System Preamble +## Basic Rules +You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions. + +# User Preamble +You are a knowledgable assistant. You can answer questions and perform tasks. + +## Available Tools +Here is a list of tools that you have available to you: + +```python +def get_current_weather(format: string, location: string, ) -> List[Dict]: + """Get the current weather + + Args: + format (string): The temperature unit to use. Infer this from the users location. + location (string): The city and state, e.g. San Francisco, CA + """ + pass +```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> +Action: ```json +[ + { + "tool_name": "get_current_weather", + "parameters": {"format":"celsius","location":"Paris, France"} + } +]``` +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|> +22<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example: +```json +[ + { + "tool_name": title of the tool in the specification, + "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters + } +]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> \ No newline at end of file diff --git a/server/testdata/tools/firefunction.gotmpl b/server/testdata/tools/firefunction.gotmpl new file mode 100644 index 000000000..bca88b3bd --- /dev/null +++ b/server/testdata/tools/firefunction.gotmpl @@ -0,0 +1,31 @@ +{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|> +{{- if .System }} +{{ .System }} +{{- end }} +In addition to plain text responses, you can chose to call one or more of the provided functions. + +Use the following rule to decide when to call a function: + * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so + * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls + +If you decide to call functions: + * prefix function calls with functools marker (no closing marker required) + * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] + * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples + * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 + * make sure you pick the right functions that match the user intent + +Available functions as JSON spec: +{{- if .Tools }} +{{ json .Tools }} +{{- end }}<|eot_id|> +{{- end }} +{{- range .Messages }}<|start_header_id|> +{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }} +{{- end }}<|end_header_id|> +{{- if .Content }}{{ .Content }} +{{- else if .ToolCalls }} functools[ +{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }} +{{- end }}] +{{- end }}<|eot_id|> +{{- end }}<|start_header_id|>assistant<|end_header_id|> \ No newline at end of file diff --git a/server/testdata/tools/firefunction.out b/server/testdata/tools/firefunction.out new file mode 100644 index 000000000..be50175ef --- /dev/null +++ b/server/testdata/tools/firefunction.out @@ -0,0 +1,17 @@ +<|start_header_id|>system<|end_header_id|> +You are a knowledgable assistant. You can answer questions and perform tasks. +In addition to plain text responses, you can chose to call one or more of the provided functions. + +Use the following rule to decide when to call a function: + * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so + * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls + +If you decide to call functions: + * prefix function calls with functools marker (no closing marker required) + * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] + * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples + * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 + * make sure you pick the right functions that match the user intent + +Available functions as JSON spec: +[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> \ No newline at end of file diff --git a/server/testdata/tools/messages.json b/server/testdata/tools/messages.json new file mode 100644 index 000000000..1a3d1f56c --- /dev/null +++ b/server/testdata/tools/messages.json @@ -0,0 +1,39 @@ +[ + { + "role": "system", + "content": "You are a knowledgable assistant. You can answer questions and perform tasks." + }, + { + "role": "user", + "content": "What's the weather like today in Paris?" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "89a1e453-0bce-4de3-a456-c54bed09c520", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": { + "location": "Paris, France", + "format": "celsius" + } + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520", + "content": "22" + }, + { + "role": "assistant", + "content": "The current temperature in Paris, France is 22 degrees Celsius." + }, + { + "role": "user", + "content": "What's the weather like today in San Francisco and Toronto?" + } +] diff --git a/server/testdata/tools/mistral.gotmpl b/server/testdata/tools/mistral.gotmpl new file mode 100644 index 000000000..a98bc7ad6 --- /dev/null +++ b/server/testdata/tools/mistral.gotmpl @@ -0,0 +1,15 @@ +{{- range $index, $_ := .Messages }} +{{- if eq .Role "user" }} +{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS] +{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }} + +{{ end }}{{ .Content }}[/INST] +{{- else if eq .Role "assistant" }} +{{- if .Content }} {{ .Content }} +{{- else if .ToolCalls }}[TOOL_CALLS] [ +{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}} +{{- end }}] +{{- end }} +{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS] +{{- end }} +{{- end }} \ No newline at end of file diff --git a/server/testdata/tools/mistral.out b/server/testdata/tools/mistral.out new file mode 100644 index 000000000..31d8cdd62 --- /dev/null +++ b/server/testdata/tools/mistral.out @@ -0,0 +1,3 @@ +[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}][TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks. + +What's the weather like today in San Francisco and Toronto?[/INST] \ No newline at end of file diff --git a/server/testdata/tools/tools.json b/server/testdata/tools/tools.json new file mode 100644 index 000000000..17260bf83 --- /dev/null +++ b/server/testdata/tools/tools.json @@ -0,0 +1,30 @@ +[ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit" + ], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": [ + "location", + "format" + ] + } + } + } +] diff --git a/template/template.go b/template/template.go index 0e23cf1ce..7cdb30ef1 100644 --- a/template/template.go +++ b/template/template.go @@ -13,7 +13,6 @@ import ( "sync" "text/template" "text/template/parse" - "time" "github.com/agnivade/levenshtein" "github.com/ollama/ollama/api" @@ -108,9 +107,6 @@ var funcs = template.FuncMap{ b, _ := json.Marshal(v) return string(b) }, - "now": func() string { - return time.Now().Format("2006-01-02 15:04:05") - }, } func Parse(s string) (*Template, error) { From 7ac6d462ecb9a26591b5f7457bea32c1cd63541f Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 15 Jul 2024 17:39:44 -0700 Subject: [PATCH 18/25] server: return empty slice on empty `/api/embed` request (#5713) * server: return empty slice on empty `/api/embed` request * fix tests --- api/types.go | 2 +- server/routes_test.go | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/api/types.go b/api/types.go index 3b607cecb..649860a1f 100644 --- a/api/types.go +++ b/api/types.go @@ -206,7 +206,7 @@ type EmbedRequest struct { // EmbedResponse is the response from [Client.Embed]. type EmbedResponse struct { Model string `json:"model"` - Embeddings [][]float32 `json:"embeddings,omitempty"` + Embeddings [][]float32 `json:"embeddings"` } // EmbeddingRequest is the request passed to [Client.Embeddings]. diff --git a/server/routes_test.go b/server/routes_test.go index 70622e9b0..97786ba2b 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -306,8 +306,12 @@ func Test_Routes(t *testing.T) { t.Fatalf("expected model t-bone, got %s", embedResp.Model) } - if embedResp.Embeddings != nil { - t.Fatalf("expected embeddings to be nil, got %v", embedResp.Embeddings) + if embedResp.Embeddings == nil { + t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings) + } + + if len(embedResp.Embeddings) != 0 { + t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings) } }, }, From 4a565cbf9417ab6ec4560d334d556b5e97e23be9 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Sat, 13 Jul 2024 17:46:24 -0700 Subject: [PATCH 19/25] add chat and generate tests with mock runner --- llm/gguf.go | 1 + server/prompt_test.go | 15 +- server/routes_create_test.go | 18 + server/routes_delete_test.go | 5 + server/routes_generate_test.go | 651 +++++++++++++++++++++++++++++++++ server/routes_list_test.go | 3 + 6 files changed, 679 insertions(+), 14 deletions(-) create mode 100644 server/routes_generate_test.go diff --git a/llm/gguf.go b/llm/gguf.go index 4d343a1bd..a8427aed8 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -537,6 +537,7 @@ var ggufKVOrder = map[string][]string{ "tokenizer.ggml.add_bos_token", "tokenizer.ggml.add_eos_token", "tokenizer.chat_template", + "bert.pooling_type", }, } diff --git a/server/prompt_test.go b/server/prompt_test.go index 9c4da0685..02d23785f 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -3,7 +3,6 @@ package server import ( "bytes" "context" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -11,14 +10,6 @@ import ( "github.com/ollama/ollama/template" ) -func tokenize(_ context.Context, s string) (tokens []int, err error) { - for range strings.Fields(s) { - tokens = append(tokens, len(tokens)) - } - - return -} - func TestChatPrompt(t *testing.T) { type expect struct { prompt string @@ -192,15 +183,11 @@ func TestChatPrompt(t *testing.T) { t.Run(tt.name, func(t *testing.T) { model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} - prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil) + prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil) if err != nil { t.Fatal(err) } - if tt.prompt != prompt { - t.Errorf("expected %q, got %q", tt.prompt, prompt) - } - if diff := cmp.Diff(prompt, tt.prompt); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 04174b92e..cb548ebda 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) { } func TestCreateFromBin(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) { } func TestCreateFromModel(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) { } func TestCreateRemovesLayers(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) { } func TestCreateUnsetsSystem(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) { } func TestCreateMergeParameters(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) { } func TestCreateReplacesMessages(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) { } func TestCreateTemplateSystem(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -480,6 +494,8 @@ func TestCreateTemplateSystem(t *testing.T) { } func TestCreateLicenses(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -526,6 +542,8 @@ func TestCreateLicenses(t *testing.T) { } func TestCreateDetectTemplate(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go index 00303bd17..33a97a73d 100644 --- a/server/routes_delete_test.go +++ b/server/routes_delete_test.go @@ -8,12 +8,15 @@ import ( "path/filepath" "testing" + "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) func TestDelete(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -77,6 +80,8 @@ func TestDelete(t *testing.T) { } func TestDeleteDuplicateLayers(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) var s Server diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go new file mode 100644 index 000000000..9d8993284 --- /dev/null +++ b/server/routes_generate_test.go @@ -0,0 +1,651 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/llm" +) + +type mockRunner struct { + llm.LlamaServer + + // CompletionRequest is only valid until the next call to Completion + llm.CompletionRequest + llm.CompletionResponse +} + +func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + m.CompletionRequest = r + fn(m.CompletionResponse) + return nil +} + +func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) { + for range strings.Fields(s) { + tokens = append(tokens, len(tokens)) + } + + return +} + +func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) { + return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { + return mock, nil + } +} + +func TestGenerateChat(t *testing.T) { + gin.SetMode(gin.TestMode) + + mock := mockRunner{ + CompletionResponse: llm.CompletionResponse{ + Done: true, + DoneReason: "stop", + PromptEvalCount: 1, + PromptEvalDuration: 1, + EvalCount: 1, + EvalDuration: 1, + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: gpu.GetGPUInfo, + getCpuFn: gpu.GetCPUInfo, + reschedDelay: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { + req.successCh <- &runnerRef{ + llama: &mock, + } + }, + }, + } + + go s.sched.Run(context.TODO()) + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf(`FROM %s + TEMPLATE """ +{{- if .System }}System: {{ .System }} {{ end }} +{{- if .Prompt }}User: {{ .Prompt }} {{ end }} +{{- if .Response }}Assistant: {{ .Response }} {{ end }}""" +`, createBinFile(t, llm.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(1), + "llama.context_length": uint32(8192), + "llama.embedding_length": uint32(4096), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []llm.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + })), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("missing body", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, nil) + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("missing model", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{}) + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("missing capabilities", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "bert", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ + "general.architecture": "bert", + "bert.pooling_type": uint32(0), + }, []llm.Tensor{})), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "bert", + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("load model", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test", + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var actual api.ChatResponse + if err := json.NewDecoder(w.Body).Decode(&actual); err != nil { + t.Fatal(err) + } + + if actual.Model != "test" { + t.Errorf("expected model test, got %s", actual.Model) + } + + if !actual.Done { + t.Errorf("expected done true, got false") + } + + if actual.DoneReason != "load" { + t.Errorf("expected done reason load, got %s", actual.DoneReason) + } + }) + + checkChatResponse := func(t *testing.T, body io.Reader, model, content string) { + t.Helper() + + var actual api.ChatResponse + if err := json.NewDecoder(body).Decode(&actual); err != nil { + t.Fatal(err) + } + + if actual.Model != model { + t.Errorf("expected model test, got %s", actual.Model) + } + + if !actual.Done { + t.Errorf("expected done false, got true") + } + + if actual.DoneReason != "stop" { + t.Errorf("expected done reason stop, got %s", actual.DoneReason) + } + + if diff := cmp.Diff(actual.Message, api.Message{ + Role: "assistant", + Content: content, + }); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + if actual.PromptEvalCount == 0 { + t.Errorf("expected prompt eval count > 0, got 0") + } + + if actual.PromptEvalDuration == 0 { + t.Errorf("expected prompt eval duration > 0, got 0") + } + + if actual.EvalCount == 0 { + t.Errorf("expected eval count > 0, got 0") + } + + if actual.EvalDuration == 0 { + t.Errorf("expected eval duration > 0, got 0") + } + + if actual.LoadDuration == 0 { + t.Errorf("expected load duration > 0, got 0") + } + + if actual.TotalDuration == 0 { + t.Errorf("expected load duration > 0, got 0") + } + } + + mock.CompletionResponse.Content = "Hi!" + t.Run("messages", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test", + Messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkChatResponse(t, w.Body, "test", "Hi!") + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Model: "test-system", + Modelfile: "FROM test\nSYSTEM You are a helpful assistant.", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("messages with model system", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-system", + Messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkChatResponse(t, w.Body, "test-system", "Hi!") + }) + + mock.CompletionResponse.Content = "Abra kadabra!" + t.Run("messages with system", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-system", + Messages: []api.Message{ + {Role: "system", Content: "You can perform magic tricks."}, + {Role: "user", Content: "Hello!"}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkChatResponse(t, w.Body, "test-system", "Abra kadabra!") + }) + + t.Run("messages with interleaved system", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-system", + Messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + {Role: "assistant", Content: "I can help you with that."}, + {Role: "system", Content: "You can perform magic tricks."}, + {Role: "user", Content: "Help me write tests."}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkChatResponse(t, w.Body, "test-system", "Abra kadabra!") + }) +} + +func TestGenerate(t *testing.T) { + gin.SetMode(gin.TestMode) + + mock := mockRunner{ + CompletionResponse: llm.CompletionResponse{ + Done: true, + DoneReason: "stop", + PromptEvalCount: 1, + PromptEvalDuration: 1, + EvalCount: 1, + EvalDuration: 1, + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: gpu.GetGPUInfo, + getCpuFn: gpu.GetCPUInfo, + reschedDelay: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { + req.successCh <- &runnerRef{ + llama: &mock, + } + }, + }, + } + + go s.sched.Run(context.TODO()) + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf(`FROM %s + TEMPLATE """ +{{- if .System }}System: {{ .System }} {{ end }} +{{- if .Prompt }}User: {{ .Prompt }} {{ end }} +{{- if .Response }}Assistant: {{ .Response }} {{ end }}""" +`, createBinFile(t, llm.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(1), + "llama.context_length": uint32(8192), + "llama.embedding_length": uint32(4096), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []llm.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + })), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("missing body", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, nil) + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("missing model", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{}) + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("missing capabilities", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "bert", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ + "general.architecture": "bert", + "bert.pooling_type": uint32(0), + }, []llm.Tensor{})), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "bert", + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("load model", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var actual api.GenerateResponse + if err := json.NewDecoder(w.Body).Decode(&actual); err != nil { + t.Fatal(err) + } + + if actual.Model != "test" { + t.Errorf("expected model test, got %s", actual.Model) + } + + if !actual.Done { + t.Errorf("expected done true, got false") + } + + if actual.DoneReason != "load" { + t.Errorf("expected done reason load, got %s", actual.DoneReason) + } + }) + + checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) { + t.Helper() + + var actual api.GenerateResponse + if err := json.NewDecoder(body).Decode(&actual); err != nil { + t.Fatal(err) + } + + if actual.Model != model { + t.Errorf("expected model test, got %s", actual.Model) + } + + if !actual.Done { + t.Errorf("expected done false, got true") + } + + if actual.DoneReason != "stop" { + t.Errorf("expected done reason stop, got %s", actual.DoneReason) + } + + if actual.Response != content { + t.Errorf("expected response %s, got %s", content, actual.Response) + } + + if actual.Context == nil { + t.Errorf("expected context not nil") + } + + if actual.PromptEvalCount == 0 { + t.Errorf("expected prompt eval count > 0, got 0") + } + + if actual.PromptEvalDuration == 0 { + t.Errorf("expected prompt eval duration > 0, got 0") + } + + if actual.EvalCount == 0 { + t.Errorf("expected eval count > 0, got 0") + } + + if actual.EvalDuration == 0 { + t.Errorf("expected eval duration > 0, got 0") + } + + if actual.LoadDuration == 0 { + t.Errorf("expected load duration > 0, got 0") + } + + if actual.TotalDuration == 0 { + t.Errorf("expected load duration > 0, got 0") + } + } + + mock.CompletionResponse.Content = "Hi!" + t.Run("prompt", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "Hello!", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkGenerateResponse(t, w.Body, "test", "Hi!") + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Model: "test-system", + Modelfile: "FROM test\nSYSTEM You are a helpful assistant.", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("prompt with model system", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-system", + Prompt: "Hello!", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkGenerateResponse(t, w.Body, "test-system", "Hi!") + }) + + mock.CompletionResponse.Content = "Abra kadabra!" + t.Run("prompt with system", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-system", + Prompt: "Hello!", + System: "You can perform magic tricks.", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!") + }) + + t.Run("prompt with template", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-system", + Prompt: "Help me write tests.", + System: "You can perform magic tricks.", + Template: `{{- if .System }}{{ .System }} {{ end }} +{{- if .Prompt }}### USER {{ .Prompt }} {{ end }} +{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!") + }) + + t.Run("raw", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-system", + Prompt: "Help me write tests.", + Raw: true, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) +} diff --git a/server/routes_list_test.go b/server/routes_list_test.go index d04be9d63..c2d9c1137 100644 --- a/server/routes_list_test.go +++ b/server/routes_list_test.go @@ -7,11 +7,14 @@ import ( "slices" "testing" + "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" ) func TestList(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("OLLAMA_MODELS", t.TempDir()) envconfig.LoadConfig() From 4cb5d7decc08f4c7c136f81b2b5f1c74ebffbc62 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 16 Jul 2024 11:09:00 -0700 Subject: [PATCH 20/25] server: omit model system prompt if empty (#5717) --- server/routes.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/routes.go b/server/routes.go index 9712d8950..d0cbe6ccd 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1310,7 +1310,7 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - if req.Messages[0].Role != "system" { + if req.Messages[0].Role != "system" && m.System != "" { req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...) } From 5afbb60fc452965a4a53f1e46816ea41298269c6 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 16 Jul 2024 09:38:46 -0700 Subject: [PATCH 21/25] fix unmarshal type errors --- server/model.go | 48 +++++++++++++++++--------------------------- server/model_test.go | 33 +++++++++++++++++++----------- 2 files changed, 39 insertions(+), 42 deletions(-) diff --git a/server/model.go b/server/model.go index be318db9c..9e22d63a5 100644 --- a/server/model.go +++ b/server/model.go @@ -327,7 +327,8 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { var kv map[string]string // execute the subtree with placeholders to identify the keys - if err := json.Unmarshal(b.Bytes(), &kv); err != nil { + // trim any commands that might exist in the template + if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil { return nil, false } @@ -342,35 +343,26 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { } } - var sm []map[string]any - decoder := json.NewDecoder(strings.NewReader(s)) - for { - // incrementally decode the JSON into a list of JSON objects - // skipping over any invalid tokens - if err := decoder.Decode(&sm); err != nil { - if errors.Is(err, io.EOF) { - break - } - - if errors.As(err, new(*json.SyntaxError)) { - r := decoder.Buffered() - if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil { - break - } - - decoder = json.NewDecoder(r) - continue - } - + var objs []map[string]any + for offset := 0; offset < len(s); { + if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) { + break + } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) { + // skip over any syntax errors + offset += int(syntax.Offset) + } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) { + // skip over any unmarshalable types + offset += int(unmarshalType.Offset) + } else if err != nil { return nil, false + } else { + // break when an object is decoded + break } - - // break as soon as a valid object is decoded - break } var toolCalls []api.ToolCall - for _, kv := range sm { + for _, kv := range objs { call := api.ToolCall{ ID: uuid.New().String(), Type: "function", @@ -388,9 +380,5 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { toolCalls = append(toolCalls, call) } - if len(toolCalls) > 0 { - return toolCalls, true - } - - return nil, false + return toolCalls, len(toolCalls) > 0 } diff --git a/server/model_test.go b/server/model_test.go index 025781928..d39f28911 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -136,11 +136,16 @@ func TestExecuteWithTools(t *testing.T) { cases := []struct { model string output string + ok bool }{ - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, + {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] -The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}, +The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, + {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: + + [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, + {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, {"command-r-plus", "Action: ```json" + ` [ { @@ -158,8 +163,10 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`} } } ] -` + "```"}, - {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, +` + "```", true}, + {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, + {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, + {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, } var tools []api.Tool @@ -216,17 +223,19 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`} t.Run("parse", func(t *testing.T) { m := &Model{Template: tmpl} actual, ok := m.parseToolCalls(tt.output) - if !ok { - t.Fatal("failed to parse tool calls") + if ok != tt.ok { + t.Fatalf("expected %t, got %t", tt.ok, ok) } - for i := range actual { - // ID is randomly generated so clear it for comparison - actual[i].ID = "" - } + if tt.ok { + for i := range actual { + // ID is randomly generated so clear it for comparison + actual[i].ID = "" + } - if diff := cmp.Diff(actual, calls); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) + if diff := cmp.Diff(actual, calls); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } } }) }) From 987dbab0b063653b2be71060449c8add7b76cc6e Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Tue, 16 Jul 2024 13:36:08 -0700 Subject: [PATCH 22/25] OpenAI: /v1/embeddings compatibility (#5285) * OpenAI v1 models * Empty List Testing * Add back envconfig * v1/models docs * Remove Docs * OpenAI batch embed compatibility * merge conflicts * integrate with api/embed * ep * merge conflicts * request tests * rm resp test * merge conflict * merge conflict * test fixes * test fn renaming * input validation for empty string --------- Co-authored-by: jmorganca --- openai/openai.go | 111 ++++++++++++++++++++++++++++++++++++++++++ openai/openai_test.go | 72 +++++++++++++++++++++++++++ server/routes.go | 1 + 3 files changed, 184 insertions(+) diff --git a/openai/openai.go b/openai/openai.go index b289d73e8..88bdaec1f 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -61,6 +61,11 @@ type ResponseFormat struct { Type string `json:"type"` } +type EmbedRequest struct { + Input any `json:"input"` + Model string `json:"model"` +} + type ChatCompletionRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -134,11 +139,23 @@ type Model struct { OwnedBy string `json:"owned_by"` } +type Embedding struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` +} + type ListCompletion struct { Object string `json:"object"` Data []Model `json:"data"` } +type EmbeddingList struct { + Object string `json:"object"` + Data []Embedding `json:"data"` + Model string `json:"model"` +} + func NewError(code int, message string) ErrorResponse { var etype string switch code { @@ -262,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion { } } +func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { + if r.Embeddings != nil { + var data []Embedding + for i, e := range r.Embeddings { + data = append(data, Embedding{ + Object: "embedding", + Embedding: e, + Index: i, + }) + } + + return EmbeddingList{ + Object: "list", + Data: data, + Model: model, + } + } + + return EmbeddingList{} +} + func toModel(r api.ShowResponse, m string) Model { return Model{ Id: m, @@ -465,6 +503,11 @@ type RetrieveWriter struct { model string } +type EmbedWriter struct { + BaseWriter + model string +} + func (w *BaseWriter) writeError(code int, data []byte) (int, error) { var serr api.StatusError err := json.Unmarshal(data, &serr) @@ -630,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) { return w.writeResponse(data) } +func (w *EmbedWriter) writeResponse(data []byte) (int, error) { + var embedResponse api.EmbedResponse + err := json.Unmarshal(data, &embedResponse) + + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse)) + + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *EmbedWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(code, data) + } + + return w.writeResponse(data) +} + func ListMiddleware() gin.HandlerFunc { return func(c *gin.Context) { w := &ListWriter{ @@ -693,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc { id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), } + c.Writer = w + c.Next() + } +} + +func EmbeddingsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req EmbedRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return + } + + if req.Input == "" { + req.Input = []string{""} + } + + if req.Input == nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) + return + } + + if v, ok := req.Input.([]any); ok && len(v) == 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) + return + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &EmbedWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + model: req.Model, + } + c.Writer = w c.Next() diff --git a/openai/openai_test.go b/openai/openai_test.go index 99f8baaf3..5fc22b887 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -161,6 +161,78 @@ func TestMiddlewareRequests(t *testing.T) { } }, }, + { + Name: "embed handler single input", + Method: http.MethodPost, + Path: "/api/embed", + Handler: EmbeddingsMiddleware, + Setup: func(t *testing.T, req *http.Request) { + body := EmbedRequest{ + Input: "Hello", + Model: "test-model", + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, req *http.Request) { + var embedReq api.EmbedRequest + if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { + t.Fatal(err) + } + + if embedReq.Input != "Hello" { + t.Fatalf("expected 'Hello', got %s", embedReq.Input) + } + + if embedReq.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", embedReq.Model) + } + }, + }, + { + Name: "embed handler batch input", + Method: http.MethodPost, + Path: "/api/embed", + Handler: EmbeddingsMiddleware, + Setup: func(t *testing.T, req *http.Request) { + body := EmbedRequest{ + Input: []string{"Hello", "World"}, + Model: "test-model", + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, req *http.Request) { + var embedReq api.EmbedRequest + if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { + t.Fatal(err) + } + + input, ok := embedReq.Input.([]any) + + if !ok { + t.Fatalf("expected input to be a list") + } + + if input[0].(string) != "Hello" { + t.Fatalf("expected 'Hello', got %s", input[0]) + } + + if input[1].(string) != "World" { + t.Fatalf("expected 'World', got %s", input[1]) + } + + if embedReq.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", embedReq.Model) + } + }, + }, } gin.SetMode(gin.TestMode) diff --git a/server/routes.go b/server/routes.go index d0cbe6ccd..d22a099a4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1064,6 +1064,7 @@ func (s *Server) GenerateRoutes() http.Handler { // Compatibility endpoints r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) + r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) From 5a83f79afdf14b60008120d6fbd4fe94ba3f5241 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 16 Jul 2024 13:48:38 -0700 Subject: [PATCH 23/25] remove unneeded tool calls --- api/types.go | 2 -- server/model.go | 7 +------ server/model_test.go | 7 ------- 3 files changed, 1 insertion(+), 15 deletions(-) diff --git a/api/types.go b/api/types.go index e670d1149..b18ee2287 100644 --- a/api/types.go +++ b/api/types.go @@ -115,8 +115,6 @@ type Message struct { } type ToolCall struct { - ID string `json:"id"` - Type string `json:"type"` Function struct { Name string `json:"name"` Arguments map[string]any `json:"arguments"` diff --git a/server/model.go b/server/model.go index 9e22d63a5..de65d6b61 100644 --- a/server/model.go +++ b/server/model.go @@ -16,7 +16,6 @@ import ( "strings" "text/template/parse" - "github.com/google/uuid" "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" "github.com/ollama/ollama/llm" @@ -363,11 +362,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { var toolCalls []api.ToolCall for _, kv := range objs { - call := api.ToolCall{ - ID: uuid.New().String(), - Type: "function", - } - + var call api.ToolCall for k, v := range kv { switch k { case name: diff --git a/server/model_test.go b/server/model_test.go index d39f28911..2e9dad3dd 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -181,7 +181,6 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, calls := []api.ToolCall{ { - Type: "function", Function: function{ Name: "get_current_weather", Arguments: map[string]any{ @@ -191,7 +190,6 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, }, }, { - Type: "function", Function: function{ Name: "get_current_weather", Arguments: map[string]any{ @@ -228,11 +226,6 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, } if tt.ok { - for i := range actual { - // ID is randomly generated so clear it for comparison - actual[i].ID = "" - } - if diff := cmp.Diff(actual, calls); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } From 97c20ede33ca67f8bf21e55ffe318831d83290d0 Mon Sep 17 00:00:00 2001 From: Thorsten Sommer Date: Tue, 16 Jul 2024 23:24:27 +0200 Subject: [PATCH 24/25] README: Added AI Studio to the list of UIs (#5721) * Added AI Studio to the list of UIs --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index eb5e85329..40f4aeded 100644 --- a/README.md +++ b/README.md @@ -294,6 +294,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama) - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS) +- [AI Studio](https://github.com/MindWorkAI/AI-Studio) ### Terminal From d290e87513664be8ca3120348614d124991ccb86 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 20 Jun 2024 19:13:36 -0700 Subject: [PATCH 25/25] add suffix support to generate endpoint this change is triggered by the presence of "suffix", particularly useful for code completion tasks --- api/types.go | 3 ++ server/images.go | 17 ++++++-- server/routes.go | 40 +++++++++++------- server/routes_generate_test.go | 77 ++++++++++++++++++++++++++++++---- template/template.go | 10 ++++- template/template_test.go | 35 ++++++++++++++++ 6 files changed, 155 insertions(+), 27 deletions(-) diff --git a/api/types.go b/api/types.go index e670d1149..3029fca8a 100644 --- a/api/types.go +++ b/api/types.go @@ -47,6 +47,9 @@ type GenerateRequest struct { // Prompt is the textual prompt to send to the model. Prompt string `json:"prompt"` + // Suffix is the text that comes after the inserted text. + Suffix string `json:"suffix"` + // System overrides the model's default system message/prompt. System string `json:"system"` diff --git a/server/images.go b/server/images.go index 1b87888ed..5e4e88583 100644 --- a/server/images.go +++ b/server/images.go @@ -34,13 +34,19 @@ import ( "github.com/ollama/ollama/version" ) -var errCapabilityCompletion = errors.New("completion") +var ( + errCapabilities = errors.New("does not support") + errCapabilityCompletion = errors.New("completion") + errCapabilityTools = errors.New("tools") + errCapabilityInsert = errors.New("insert") +) type Capability string const ( CapabilityCompletion = Capability("completion") CapabilityTools = Capability("tools") + CapabilityInsert = Capability("insert") ) type registryOptions struct { @@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { } case CapabilityTools: if !slices.Contains(m.Template.Vars(), "tools") { - errs = append(errs, errors.New("tools")) + errs = append(errs, errCapabilityTools) + } + case CapabilityInsert: + vars := m.Template.Vars() + if !slices.Contains(vars, "suffix") { + errs = append(errs, errCapabilityInsert) } default: slog.Error("unknown capability", "capability", cap) @@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { } if err := errors.Join(errs...); err != nil { - return fmt.Errorf("does not support %w", errors.Join(errs...)) + return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...)) } return nil diff --git a/server/routes.go b/server/routes.go index d22a099a4..c7f74fa40 100644 --- a/server/routes.go +++ b/server/routes.go @@ -122,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { } caps := []Capability{CapabilityCompletion} + if req.Suffix != "" { + caps = append(caps, CapabilityInsert) + } + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) @@ -150,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { prompt := req.Prompt if !req.Raw { - var msgs []api.Message - if req.System != "" { - msgs = append(msgs, api.Message{Role: "system", Content: req.System}) - } else if m.System != "" { - msgs = append(msgs, api.Message{Role: "system", Content: m.System}) - } - - for _, i := range images { - msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) - } - - msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt}) - tmpl := m.Template if req.Template != "" { tmpl, err = template.Parse(req.Template) @@ -183,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) { b.WriteString(s) } - if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil { + var values template.Values + if req.Suffix != "" { + values.Prompt = prompt + values.Suffix = req.Suffix + } else { + var msgs []api.Message + if req.System != "" { + msgs = append(msgs, api.Message{Role: "system", Content: req.System}) + } else if m.System != "" { + msgs = append(msgs, api.Message{Role: "system", Content: m.System}) + } + + for _, i := range images { + msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) + } + + values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) + } + + if err := tmpl.Execute(&b, values); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -1394,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) { func handleScheduleError(c *gin.Context, name string, err error) { switch { - case errors.Is(err, errRequired): + case errors.Is(err, errCapabilities), errors.Is(err, errRequired): c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) case errors.Is(err, context.Canceled): c.JSON(499, gin.H{"error": "request canceled"}) diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 9d8993284..c914b3006 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) { getCpuFn: gpu.GetCPUInfo, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { + // add 10ms delay to simulate loading + time.Sleep(10 * time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, } @@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) { go s.sched.Run(context.TODO()) w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "test", + Model: "test", Modelfile: fmt.Sprintf(`FROM %s TEMPLATE """ {{- if .System }}System: {{ .System }} {{ end }} @@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) { } }) - t.Run("missing capabilities", func(t *testing.T) { + t.Run("missing capabilities chat", func(t *testing.T) { w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "bert", + Model: "bert", Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ "general.architecture": "bert", "bert.pooling_type": uint32(0), @@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) { } if actual.TotalDuration == 0 { - t.Errorf("expected load duration > 0, got 0") + t.Errorf("expected total duration > 0, got 0") } } @@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) { go s.sched.Run(context.TODO()) w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "test", + Model: "test", Modelfile: fmt.Sprintf(`FROM %s TEMPLATE """ {{- if .System }}System: {{ .System }} {{ end }} @@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) { } }) - t.Run("missing capabilities", func(t *testing.T) { + t.Run("missing capabilities generate", func(t *testing.T) { w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "bert", + Model: "bert", Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ "general.architecture": "bert", "bert.pooling_type": uint32(0), @@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) { } }) + t.Run("missing capabilities suffix", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "def add(", + Suffix: " return c", + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + t.Run("load model", func(t *testing.T) { w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ Model: "test", @@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) { } if actual.TotalDuration == 0 { - t.Errorf("expected load duration > 0, got 0") + t.Errorf("expected total duration > 0, got 0") } } @@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) { checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!") }) + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Model: "test-suffix", + Modelfile: `FROM test +TEMPLATE """{{- if .Suffix }}
 {{ .Prompt }} {{ .Suffix }} 
+{{- else }}{{ .Prompt }}
+{{- end }}"""`,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("prompt with suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+			Suffix: "    return c",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "
 def add(     return c "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("prompt without suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
 	t.Run("raw", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 			Model:  "test-system",
diff --git a/template/template.go b/template/template.go
index 7cdb30ef1..5330c0fa9 100644
--- a/template/template.go
+++ b/template/template.go
@@ -151,6 +151,8 @@ func (t *Template) Vars() []string {
 type Values struct {
 	Messages []api.Message
 	Tools    []api.Tool
+	Prompt   string
+	Suffix   string
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
@@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
 
 func (t *Template) Execute(w io.Writer, v Values) error {
 	system, messages := collate(v.Messages)
-	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
+	if v.Prompt != "" && v.Suffix != "" {
+		return t.Template.Execute(w, map[string]any{
+			"Prompt":   v.Prompt,
+			"Suffix":   v.Suffix,
+			"Response": "",
+		})
+	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
 			"System":   system,
 			"Messages": messages,
diff --git a/template/template_test.go b/template/template_test.go
index c678f1b12..ae0db80b9 100644
--- a/template/template_test.go
+++ b/template/template_test.go
@@ -359,3 +359,38 @@ Answer: `,
 		})
 	}
 }
+
+func TestExecuteWithSuffix(t *testing.T) {
+	tmpl, err := Parse(`{{- if .Suffix }}
 {{ .Prompt }} {{ .Suffix }} 
+{{- else }}{{ .Prompt }}
+{{- end }}`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	cases := []struct {
+		name   string
+		values Values
+		expect string
+	}{
+		{
+			"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
+		},
+		{
+			"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "
 def add( return x ",
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			var b bytes.Buffer
+			if err := tmpl.Execute(&b, tt.values); err != nil {
+				t.Fatal(err)
+			}
+
+			if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}