From 610054a2348160400541d2a6afbdebeb1c6ce53f Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Fri, 25 Apr 2025 16:35:16 -0700 Subject: [PATCH] model: support tools streaming and improve parsing --- server/model.go | 214 ++++++++------- server/model_test.go | 290 +++++++++++---------- server/routes.go | 91 ++++--- server/testdata/tools/qwen2.5-coder.gotmpl | 51 ++++ server/testdata/tools/qwen2.5-coder.out | 31 +++ server/tools.go | 175 +++++++++++++ server/tools_test.go | 281 ++++++++++++++++++++ 7 files changed, 858 insertions(+), 275 deletions(-) create mode 100644 server/testdata/tools/qwen2.5-coder.gotmpl create mode 100644 server/testdata/tools/qwen2.5-coder.out create mode 100644 server/tools.go create mode 100644 server/tools_test.go diff --git a/server/model.go b/server/model.go index 2149ff855..b5c91ef1e 100644 --- a/server/model.go +++ b/server/model.go @@ -12,6 +12,7 @@ import ( "os" "slices" "strings" + gotmpl "text/template" "text/template/parse" "github.com/ollama/ollama/api" @@ -129,33 +130,122 @@ func detectContentType(r io.Reader) (string, error) { return "unknown", nil } -func parseObjects(s string) []map[string]any { - var objs []map[string]any - for offset := 0; offset < len(s); { - var obj map[string]any - decoder := json.NewDecoder(strings.NewReader(s[offset:])) - if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - break - } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) { - // skip over any syntax errors - offset += int(syntax.Offset) - } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) { - // skip over any unmarshalable types - offset += int(unmarshalType.Offset) - } else if err != nil { - return nil - } else { - offset += int(decoder.InputOffset()) - objs = append(objs, obj) +// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls" +func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, 
bool) { + if tmpl == nil || tmpl.Tree == nil { + slog.Debug("TextAfterToolCalls: template or tree is nil") + return "", false + } + + var result string + var found bool + + var walk func(nodes []parse.Node) + walk = func(nodes []parse.Node) { + for _, node := range nodes { + if found { + return + } + + switch n := node.(type) { + case *parse.IfNode: + if nodeContainsToolCalls(n) { + // Collect immediate TextNode(s) at start of IfNode's list + var sb strings.Builder + for _, innerNode := range n.List.Nodes { + if tn, ok := innerNode.(*parse.TextNode); ok { + sb.Write(tn.Text) + } else { + // Stop at first non-text node + break + } + } + result = sb.String() + found = true + return + } + // Recurse into child nodes + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.ListNode: + walk(n.Nodes) + case *parse.RangeNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.WithNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + default: + // Continue to next node + continue + } + + if found { + return + } } } - return objs + walk(tmpl.Tree.Root.Nodes) + return result, found } -// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls. 
-// mxyng: this only really works if the input contains tool calls in some JSON format -func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { +// Helper to detect if a node's condition includes ".ToolCalls" +func nodeContainsToolCalls(n *parse.IfNode) bool { + for _, cmd := range n.Pipe.Cmds { + for _, arg := range cmd.Args { + if field, ok := arg.(*parse.FieldNode); ok { + if slices.Contains(field.Ident, "ToolCalls") { + return true + } + } + } + } + return false +} + +func ToolToken(tmpl *gotmpl.Template) (string, bool) { + tokenText, ok := extractToolCallsTemplate(tmpl) + if !ok { + return "", false + } + tokenText = strings.TrimSpace(tokenText) + if tokenText == "" { + return "", false + } + first := strings.Fields(tokenText)[0] + + start := -1 + end := -1 + for i, r := range tokenText { + if r == '<' || r == '[' { + start = i + } + if (r == '>' || r == ']') && start != -1 { + end = i + break + } + } + if start != -1 && end != -1 { + // return the token including the [ or < and the ] or > + return tokenText[start : end+1], true + } else if start != -1 { + // get until the [ or < - in the case tag was not closed + return tokenText[:start], true + } else if end != -1 { + // get after the ] or > - in the case tag was not opened + return tokenText[end+1:], true + } + return first, true +} + +func ToolTemplate(m *Model) (*gotmpl.Template, bool) { // create a subtree from the node that ranges over .ToolCalls tmpl := m.Template.Subtree(func(n parse.Node) bool { if t, ok := n.(*parse.RangeNode); ok { @@ -169,83 +259,5 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { return nil, false } - var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ - "ToolCalls": { - { - Function: api.ToolCallFunction{ - Name: "@@name@@", - Arguments: api.ToolCallFunctionArguments{ - "@@argument@@": 1, - }, - }, - }, - }, - }); err != nil { - return nil, false - } - - templateObjects := parseObjects(b.String()) - if len(templateObjects) == 0 
{ - return nil, false - } - - // find the keys that correspond to the name and arguments fields - var name, arguments string - for k, v := range templateObjects[0] { - switch v.(type) { - case string: - name = k - case map[string]any: - arguments = k - } - } - - if name == "" || arguments == "" { - return nil, false - } - - responseObjects := parseObjects(s) - if len(responseObjects) == 0 { - return nil, false - } - - // collect all nested objects - var collect func(any) []map[string]any - collect = func(obj any) (all []map[string]any) { - switch o := obj.(type) { - case map[string]any: - all = append(all, o) - for _, v := range o { - all = append(all, collect(v)...) - } - case []any: - for _, v := range o { - all = append(all, collect(v)...) - } - } - - return all - } - - var objs []map[string]any - for _, p := range responseObjects { - objs = append(objs, collect(p)...) - } - - var toolCalls []api.ToolCall - for _, kv := range objs { - n, nok := kv[name].(string) - a, aok := kv[arguments].(map[string]any) - if nok && aok { - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: n, - Arguments: a, - }, - }) - } - } - - return toolCalls, len(toolCalls) > 0 + return tmpl, true } diff --git a/server/model_test.go b/server/model_test.go index e5c2f2bb2..498fdb408 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -1,178 +1,184 @@ package server import ( - "bytes" - "encoding/json" - "fmt" - "os" - "path/filepath" "testing" - - "github.com/google/go-cmp/cmp" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" + gotmpl "text/template" ) -func readFile(t *testing.T, base, name string) *bytes.Buffer { - t.Helper() - - bts, err := os.ReadFile(filepath.Join(base, name)) - if err != nil { - t.Fatal(err) - } - - return bytes.NewBuffer(bts) -} - -func TestExecuteWithTools(t *testing.T) { - p := filepath.Join("testdata", "tools") +func TestToolToken(t *testing.T) { cases := []struct { - model string - output 
string - ok bool + name string + template string + want string + ok bool }{ - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] - -The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false}, - {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: - - [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"command-r-plus", "Action: ```json" + ` -[ - { - "tool_name": "get_current_weather", - "parameters": { - "format": "fahrenheit", - "location": "San Francisco, CA" - } - }, - { - "tool_name": "get_current_weather", - "parameters": { - "format": "celsius", - "location": "Toronto, Canada" - } - } -] -` + "```", true}, - {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", 
"arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"llama3-groq-tool-use", ` -{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} -`, true}, - {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true}, - {"nemotron", `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true}, - } - - var tools []api.Tool - if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { - t.Fatal(err) - } - - var messages []api.Message - if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { - t.Fatal(err) - } - - calls := []api.ToolCall{ { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "fahrenheit", - "location": "San Francisco, CA", - }, - }, + name: "basic tool call with action prefix", + template: "{{if .ToolCalls}}Action: ```json{{end}}", + want: "Action:", + ok: true, }, { - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "celsius", - "location": "Toronto, Canada", - }, - }, + name: "incomplete functools bracket", + template: "{{if .ToolCalls}}functools[{{end}}", + want: "functools", + ok: true, + }, + { + name: "tool call with angle brackets", + template: "{{if .ToolCalls}}Hello, world! 
{{end}}", + want: "", + ok: true, + }, + { + name: "multiple tool call formats", + template: "{{if .ToolCalls}}[tool_call] {{end}}", + want: "[tool_call]", + ok: true, + }, + { + name: "single angle bracket tool call", + template: "{{if .ToolCalls}}{{end}}", + want: "", + ok: true, + }, + { + name: "incomplete angle bracket after tool call", + template: "{{if .ToolCalls}}[tool_call] <{{end}}", + want: "[tool_call]", + ok: true, + }, + { + name: "angle bracket prefix with tool call", + template: "{{if .ToolCalls}}> {{end}}", + want: "", + ok: true, + }, + { + name: "uppercase tool call with incomplete bracket", + template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", + want: "[TOOL_CALL]", + ok: true, + }, + { + name: "uppercase tool call with adjacent bracket", + template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", + want: "[TOOL_CALL]", + ok: true, + }, + { + name: "tool call with pipe delimiters", + template: "{{if .ToolCalls}}<|tool_call|>{{end}}", + want: "<|tool_call|>", + ok: true, }, } for _, tt := range cases { - t.Run(tt.model, func(t *testing.T) { - tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) if err != nil { - t.Fatal(err) + t.Fatalf("failed to parse template: %v", err) + } + got, ok := ToolToken(tmpl) + if got != tt.want { + t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) + } + if ok != tt.ok { + t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok) } - - t.Run("template", func(t *testing.T) { - var actual bytes.Buffer - if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil { - t.Fatal(err) - } - - if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - - t.Run("parse", func(t *testing.T) { - m := &Model{Template: tmpl} - actual, ok := 
m.parseToolCalls(tt.output) - if ok != tt.ok { - t.Fatalf("expected %t, got %t", tt.ok, ok) - } - - if tt.ok { - if diff := cmp.Diff(actual, calls); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - } - }) }) } } -func TestParseObjects(t *testing.T) { - tests := []struct { - input string - want []map[string]any +func TestTextAfterToolCalls(t *testing.T) { + cases := []struct { + name string + template string + want string + ok bool }{ { - input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}}, - }, + name: "basic tool call with text after", + template: `{{if .ToolCalls}}tool response{{end}}`, + want: "tool response", + ok: true, }, { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - }, + name: "tool call with mixed content after", + template: `{{if .ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, }, { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}}, - }, + name: "tool call with no text after", + template: `{{if 
.ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, }, { - input: `{"name": "get_current_weather", "arguments": `, - want: nil, + name: "nested tool call", + template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, + want: "[TOOL_CALL]", + ok: true, + }, + { + name: "no tool calls", + template: `{{if .Something}}no tools here{{end}}`, + want: "", + ok: false, + }, + { + name: "empty template", + template: ``, + want: "", + ok: false, + }, + { + name: "multiple tool calls sections", + template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, + want: "first", + ok: true, + }, + { + name: "range over tool calls", + template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with pipe delimiters", + template: `{{if .ToolCalls}}<|tool|>{{end}}`, + want: "<|tool|>", + ok: true, + }, + { + name: "tool calls with nested template", + template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with whitespace variations", + template: `{{if .ToolCalls}} tool {{end}}`, + want: " tool ", + ok: true, }, } - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - got := parseObjects(tc.input) + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } - if diff := cmp.Diff(got, tc.want); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) + got, ok := extractToolCallsTemplate(tmpl) + if got != tt.want { + t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) + } + if ok != tt.ok { + t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) } }) } diff --git a/server/routes.go b/server/routes.go index d0b8f487e..f09b632dd 100644 --- a/server/routes.go +++ b/server/routes.go @@ -21,6 +21,7 @@ import ( "slices" "strings" "syscall" + gotmpl "text/template" "time" 
"github.com/gin-contrib/cors" @@ -1487,6 +1488,24 @@ func (s *Server) ChatHandler(c *gin.Context) { defer close(ch) var sb strings.Builder var toolCallIndex int = 0 + var templateToolToken string + var tmpl *gotmpl.Template + if len(req.Tools) > 0 { + var ok bool + templateToolToken, ok = ToolToken(m.Template.Template) + if !ok { + slog.Debug("no tool token found") + } + tmpl, ok = ToolTemplate(m) + if !ok { + slog.Debug("no tool template found") + } + } + + checkToolCall := false + if len(req.Tools) > 0 { + checkToolCall = true + } if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1507,42 +1526,50 @@ func (s *Server) ChatHandler(c *gin.Context) { } if r.Done { + if sb.Len() > 0 { + res.Message.Content = sb.String() + } res.DoneReason = r.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - // TODO: tool call checking and filtering should be moved outside of this callback once streaming - // however this was a simple change for now without reworking streaming logic of this (and other) - // handlers - if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 { - ch <- res - return - } - - // Streaming tool calls: - // If tools are recognized, use a flag to track the sending of a tool downstream - // This ensures that content is cleared from the message on the last chunk sent sb.WriteString(r.Content) - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - res.Message.ToolCalls = toolCalls - for i := range toolCalls { - toolCalls[i].Function.Index = toolCallIndex - toolCallIndex++ + if len(req.Tools) > 0 && checkToolCall { + slog.Debug("parse tool calls", "content", sb.String(), "templateToolToken", templateToolToken) + toolCalls, partial, err := ParseToolCalls(sb.String(), templateToolToken, tmpl) + if err == nil { + if partial { + // circuit break to remove tool end token + if len(toolCalls) > 0 { + sb.Reset() + } + // If the tool 
call is partial, we need to wait for the next chunk + return + } + res.Message.ToolCalls = toolCalls + for i := range toolCalls { + toolCalls[i].Function.Index = toolCallIndex + toolCallIndex++ + } + res.Message.Content = "" + sb.Reset() + ch <- res + // Only way to have multiple calls is to have [] which is derived or provided + if templateToolToken == "" { + checkToolCall = false + } + return } - res.Message.Content = "" - sb.Reset() - ch <- res - return } - if r.Done { - // Send any remaining content if no tool calls were detected - if toolCallIndex == 0 { - res.Message.Content = sb.String() - } - ch <- res + // If there is no template tool token, we don't need to check for tool calls after the first chunk + if templateToolToken == "" { + checkToolCall = false } + res.Message.Content = sb.String() + sb.Reset() + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} } @@ -1551,11 +1578,15 @@ func (s *Server) ChatHandler(c *gin.Context) { if req.Stream != nil && !*req.Stream { var resp api.ChatResponse var sb strings.Builder + var toolCalls []api.ToolCall for rr := range ch { switch t := rr.(type) { case api.ChatResponse: sb.WriteString(t.Message.Content) resp = t + if len(req.Tools) > 0 { + toolCalls = append(toolCalls, t.Message.ToolCalls...) 
+ } case gin.H: msg, ok := t["error"].(string) if !ok { @@ -1571,12 +1602,8 @@ func (s *Server) ChatHandler(c *gin.Context) { } resp.Message.Content = sb.String() - - if len(req.Tools) > 0 { - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - resp.Message.ToolCalls = toolCalls - resp.Message.Content = "" - } + if len(toolCalls) > 0 { + resp.Message.ToolCalls = toolCalls } c.JSON(http.StatusOK, resp) diff --git a/server/testdata/tools/qwen2.5-coder.gotmpl b/server/testdata/tools/qwen2.5-coder.gotmpl new file mode 100644 index 000000000..cbd7302c4 --- /dev/null +++ b/server/testdata/tools/qwen2.5-coder.gotmpl @@ -0,0 +1,51 @@ +{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|> +{{- else if .Messages }} +{{- if or .System .Tools }}<|im_start|>system +{{- if .System }} +{{ .System }} +{{- end }} +{{- if .Tools }} + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{{- range .Tools }} +{"type": "function", "function": {{ .Function }}} +{{- end }} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +{{- end }}<|im_end|> +{{ end }} +{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 -}} +{{- if eq .Role "user" }}<|im_start|>user +{{ .Content }}<|im_end|> +{{ else if eq .Role "assistant" }}<|im_start|>assistant +{{ if .Content }}{{ .Content }} +{{- else if .ToolCalls }} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }} +{{- end }}{{ if not $last }}<|im_end|> +{{ end }} +{{- else if eq .Role "tool" }}<|im_start|>user + +{{ .Content }} +<|im_end|> +{{ end }} +{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant +{{ end }} +{{- end }} +{{- else }} +{{- if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt 
}}<|im_end|> +{{ end }}<|im_start|>assistant +{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/server/testdata/tools/qwen2.5-coder.out b/server/testdata/tools/qwen2.5-coder.out new file mode 100644 index 000000000..76bfbfa98 --- /dev/null +++ b/server/testdata/tools/qwen2.5-coder.out @@ -0,0 +1,31 @@ +<|im_start|>system +You are a knowledgeable assistant. You can answer questions and perform tasks. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +What's the weather like today in Paris?<|im_end|> +<|im_start|>assistant + +{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} +<|im_end|> +<|im_start|>user + +22 +<|im_end|> +<|im_start|>assistant +The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> +<|im_start|>user +What's the weather like today in San Francisco and Toronto?<|im_end|> +<|im_start|>assistant diff --git a/server/tools.go b/server/tools.go new file mode 100644 index 000000000..aa2f729db --- /dev/null +++ b/server/tools.go @@ -0,0 +1,175 @@ +package server + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + gotmpl "text/template" + + "github.com/ollama/ollama/api" +) + +func parseObjects(s string) []map[string]any { + var objs []map[string]any + for offset 
:= 0; offset < len(s); { + var obj map[string]any + decoder := json.NewDecoder(strings.NewReader(s[offset:])) + err := decoder.Decode(&obj) + switch { + case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF): + return objs + case err != nil: + var syntax *json.SyntaxError + var unmarshalType *json.UnmarshalTypeError + switch { + case errors.As(err, &syntax): + offset += int(syntax.Offset) + continue + case errors.As(err, &unmarshalType): + offset += int(unmarshalType.Offset) + continue + default: + return nil + } + } + offset += int(decoder.InputOffset()) + objs = append(objs, obj) + } + return objs +} + +// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. +// Returns parsed tool calls and a boolean indicating if the JSON is incomplete +func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool) { + var b bytes.Buffer + if err := tmpl.Execute(&b, map[string][]api.ToolCall{ + "ToolCalls": { + { + Function: api.ToolCallFunction{ + Name: "@@name@@", + Arguments: api.ToolCallFunctionArguments{ + "@@argument@@": 1, + }, + }, + }, + }, + }); err != nil { + return nil, false + } + + templateObjects := parseObjects(b.String()) + if len(templateObjects) == 0 { + return nil, false + } + + // find the keys that correspond to the name and arguments fields + var name, arguments string + for k, v := range templateObjects[0] { + switch v.(type) { + case string: + name = k + case map[string]any: + arguments = k + } + } + + if name == "" || arguments == "" { + return nil, false + } + + responseObjects := parseObjects(s) + if len(responseObjects) == 0 { + return nil, false + } + + // collect all nested objects + var collect func(any) []map[string]any + collect = func(obj any) (all []map[string]any) { + switch o := obj.(type) { + case map[string]any: + all = append(all, o) + for _, v := range o { + all = append(all, collect(v)...) + } + case []any: + for _, v := range o { + all = append(all, collect(v)...) 
+ } + } + + return all + } + + var objs []map[string]any + for _, p := range responseObjects { + objs = append(objs, collect(p)...) + } + + var toolCalls []api.ToolCall + for _, kv := range objs { + n, nok := kv[name].(string) + a, aok := kv[arguments].(map[string]any) + if nok && aok { + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: n, + Arguments: a, + }, + }) + } + } + + return toolCalls, len(toolCalls) > 0 +} + +// routeToolParsing is a helper function that routes what kind of tool parsing to use +func routeToolParsing(s string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, bool) { + if strings.HasPrefix(s, "[{") || strings.HasPrefix(s, "```") || strings.HasPrefix(s, "{") { + if toolCalls, ok := parseJSONToolCalls(tmpl, s); ok { + return toolCalls, false, true + } + // in the case the JSON never finishes, the acuumulated content should be sent downstream + return nil, true, true + } + // TODO(parthsareen): add python tool call support + return nil, false, false +} + +// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. +// Returns tool calls, whether parsing is incomplete, and any errors. 
+func ParseToolCalls(s string, toolToken string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, error) { + if tmpl == nil { + return nil, false, fmt.Errorf("no template provided") + } + s = strings.TrimSpace(s) + if len(s) == 0 { + return nil, false, fmt.Errorf("empty input string") + } + if toolToken != "" { + if strings.HasPrefix(s, toolToken) { + s = strings.TrimSpace(s[len(toolToken):]) + tc, _, ok := routeToolParsing(s, tmpl) + if len(tc) == 0 || !ok { + return nil, true, nil + } + return tc, false, nil + // Special token end case + } else if strings.HasSuffix(s, toolToken[2:]) { + tc := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: toolToken, + }, + } + return []api.ToolCall{tc}, true, nil + } + } + + tc, partial, ok := routeToolParsing(s, tmpl) + if !ok { + return nil, false, fmt.Errorf("failed to parse tool calls for input: %q", s) + } + return tc, partial, nil +} diff --git a/server/tools_test.go b/server/tools_test.go new file mode 100644 index 000000000..9250e82ee --- /dev/null +++ b/server/tools_test.go @@ -0,0 +1,281 @@ +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" +) + +func readFile(t *testing.T, base, name string) *bytes.Buffer { + t.Helper() + + bts, err := os.ReadFile(filepath.Join(base, name)) + if err != nil { + t.Fatal(err) + } + + return bytes.NewBuffer(bts) +} + +func TestParseToolCalls(t *testing.T) { + p := filepath.Join("testdata", "tools") + t1 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + } + t2 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "format": "celsius", + "location": "Toronto, Canada", + }, + }, + } + + cases := 
[]struct { + name string + model string + output string + token string + expected []api.ToolCall + wantErr bool + }{ + { + name: "mistral invalid json", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, + token: "[TOOL_CALLS]", + expected: []api.ToolCall{}, + wantErr: true, + }, + { + name: "mistral valid json", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + token: "[TOOL_CALLS]", + expected: []api.ToolCall{t1, t2}, + wantErr: false, + }, + { + name: "mistral incomplete json", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, + token: "[TOOL_CALLS]", + expected: []api.ToolCall{}, + wantErr: true, + }, + { + name: "mistral without tool token", + model: "mistral", + output: `I'm not aware of that information. 
However, I can suggest searching for the weather using the "get_current_weather" function: + + [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + token: "[TOOL_CALLS]", + expected: []api.ToolCall{}, + wantErr: true, + }, + { + name: "mistral without tool token - tool first", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + token: "[TOOL_CALLS]", + expected: []api.ToolCall{t1, t2}, + wantErr: false, + }, + { + name: "command-r-plus with json block", + model: "command-r-plus", + output: "Action: ```json" + ` + [ + { + "tool_name": "get_current_weather", + "parameters": { + "format": "fahrenheit", + "location": "San Francisco, CA" + } + }, + { + "tool_name": "get_current_weather", + "parameters": { + "format": "celsius", + "location": "Toronto, Canada" + } + } + ] + ` + "```", + token: "Action:", + expected: []api.ToolCall{t1, t2}, + wantErr: false, + }, + { + name: "firefunction with functools", + model: "firefunction", + output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + token: "functools", + expected: []api.ToolCall{t1, t2}, + wantErr: false, + }, + { + name: "llama3 with tool call tags", + model: "llama3-groq-tool-use", + output: ` + {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} + `, + token: "", + expected: []api.ToolCall{t1}, + wantErr: false, + }, + { + name: "xlam with tool_calls wrapper", + model: "xlam", + output: `{"tool_calls": [{"name": "get_current_weather", "arguments": 
{"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, + token: "", + expected: []api.ToolCall{t1, t2}, + wantErr: false, + }, + { + name: "qwen with single tool call", + model: "qwen2.5-coder", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + token: "", + expected: []api.ToolCall{t1}, + wantErr: false, + }, + { + name: "qwen with invalid tool token", + model: "qwen2.5-coder", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + token: "[TOOL_CALLS]", + expected: []api.ToolCall{t1, t2}, + wantErr: false, + }, + { + name: "qwen with no tool calls", + model: "qwen2.5-coder", + output: " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + token: "", + expected: []api.ToolCall{}, + wantErr: true, + }, + } + + var tools []api.Tool + if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { + t.Fatal(err) + } + + var messages []api.Message + if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { + t.Fatal(err) + } + + for _, tt := range cases { + t.Run(tt.model, func(t *testing.T) { + tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) + if err != nil { + t.Fatal(err) + } + + t.Run("template", func(t *testing.T) { + var actual bytes.Buffer + if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("parse", func(t *testing.T) { + m := &Model{Template: tmpl} + tmpl, ok := ToolTemplate(m) + if 
!ok { + t.Fatal("no tool template found") + } + got := []api.ToolCall{} + tokens := strings.Fields(tt.output) + sb := strings.Builder{} + success := false + for _, tok := range tokens { + sb.WriteString(" " + tok) + toolCalls, partial, err := ParseToolCalls(sb.String(), tt.token, tmpl) + if err == nil { + success = true + } + if partial { + continue + } + got = append(got, toolCalls...) + sb.Reset() + } + + if !tt.wantErr { + if diff := cmp.Diff(got, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + } + if !success && !tt.wantErr { + t.Errorf("expected success but got errors") + } + }) + }) + } +}
+// TestParseObjects runs parseObjects on a JSON array of two objects, a single object, two whitespace-separated objects, and a truncated object; the first three cases expect one map per object and the truncated case expects nil, with results compared via cmp.Diff.
+func TestParseObjects(t *testing.T) { + tests := []struct { + input string + want []map[string]any + }{ + { + input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + want: []map[string]any{ + {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, + {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}}, + }, + }, + { + input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + want: []map[string]any{ + {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, + }, + }, + { + input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} `, + want: []map[string]any{ + {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, + {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}}, + }, + }, + { + input: 
`{"name": "get_current_weather", "arguments": `, + want: nil, // the input stops mid-object, so this case expects no parsed objects
+ }, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + got := parseObjects(tc.input) + + if diff := cmp.Diff(got, tc.want); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +}