diff --git a/server/routes.go b/server/routes.go
index 64823bd32..c1868ff89 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1529,9 +1529,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
if err == nil {
if len(content) > 0 {
res.Message.Content = content
- fmt.Println("sending content in response", content)
+ slog.Debug("tools: setting content to", "content", content)
} else if len(toolCalls) > 0 {
- fmt.Println("sending tool calls in response", toolCalls)
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else {
diff --git a/tools/tools.go b/tools/tools.go
index 1105848e9..bf72cf212 100644
--- a/tools/tools.go
+++ b/tools/tools.go
@@ -1,9 +1,7 @@
package tools
import (
- "bytes"
"errors"
- "fmt"
"io"
"log/slog"
"strings"
@@ -16,136 +14,56 @@ import (
"github.com/ollama/ollama/template"
)
-// TODO: simplify if possible
type Parser struct {
- greedy bool
+ greedyParse bool
prefixFound bool
- partialPrefix bool
+ prefixPartial bool
tmpl *gotmpl.Template
sb *strings.Builder
prefix string
index int
+ name string
+ arguments string
Done bool
}
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
-// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
-func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
- fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s)
-
- var b bytes.Buffer
- if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
- "ToolCalls": {
- {
- Function: api.ToolCallFunction{
- Name: "@@name@@",
- Arguments: api.ToolCallFunctionArguments{
- "@@argument@@": 1,
- },
- },
- },
- },
- }); err != nil {
- fmt.Printf("failed to execute template: error=%v\n", err)
- return nil, false, false
- }
-
- // this can be either a map or an array
- var temp any
- err := jsonv2.Unmarshal(b.Bytes(), &temp)
- if err != nil {
- fmt.Printf("failed to unmarshal template: error=%v\n", err)
- return nil, false, false
- }
-
- var collect func(any) []map[string]any
- collect = func(obj any) (all []map[string]any) {
- switch o := obj.(type) {
- case map[string]any:
- all = append(all, o)
- for _, v := range o {
- all = append(all, collect(v)...)
- }
- case []any:
- for _, v := range o {
- all = append(all, collect(v)...)
- }
- default:
- // TODO: err or fallback
- fmt.Printf("collect encountered unknown type: type=%T\n", obj)
- return nil
- }
-
- return all
- }
-
- var templateObjects []map[string]any
- switch t := temp.(type) {
- case map[string]any:
- templateObjects = []map[string]any{t}
- case []map[string]any:
- templateObjects = t
- // ! fallback?
- case []any:
- templateObjects = collect(t)
- }
- if len(templateObjects) == 0 {
- fmt.Println("no template objects found")
- return nil, false, false
- }
-
- // find the keys that correspond to the name and arguments fields
- var name, arguments string
- for k, v := range templateObjects[0] {
- switch v.(type) {
- case string:
- name = k
- fmt.Printf("found name field: key=%s\n", k)
- case map[string]any:
- arguments = k
- fmt.Printf("found arguments field: key=%s\n", k)
- }
- }
-
- if name == "" || arguments == "" {
- fmt.Printf("missing required fields: name_found=%v arguments_found=%v\n", name != "", arguments != "")
- return nil, false, false
- }
-
- // TODO: there is probably some underlying repeat work here to avoid
- // This incrementally decodes the JSON string and returns the first parsedobject
+// It first tries to incrementally decode the JSON to handle partial inputs.
+// Returns:
+// - []api.ToolCall: The parsed tool calls if successful
+// - bool: True if JSON is incomplete and needs more input
+func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) {
+ // First try incremental decoding to handle partial JSON
dec := jsontext.NewDecoder(strings.NewReader(s))
if got, err := dec.ReadValue(); err == nil {
s = got.String()
- fmt.Printf("decoded JSON value: value=%s\n", s)
}
- var responseObjects any
- err = jsonv2.Unmarshal([]byte(s), &responseObjects)
+ // Attempt full unmarshal of the JSON
+ var resp any
+ err := jsonv2.Unmarshal([]byte(s), &resp)
if err != nil {
+ // Handle incomplete JSON cases
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
- fmt.Println("incomplete JSON detected")
- return nil, true, false
- } else {
- fmt.Printf("failed to unmarshal response: error=%v\n", err)
- return nil, false, false
+ slog.Debug("incomplete JSON detected", "input", s)
+ return nil, true
}
+ slog.Debug("failed to unmarshal response", "error", err)
+ return nil, false
}
+ // Collect all nested objects that could contain tool calls
var objs []map[string]any
- objs = append(objs, collect(responseObjects)...)
+ objs = append(objs, collect(resp)...)
if len(objs) == 0 {
- return nil, false, false
+ return nil, false
}
- fmt.Printf("collected objects: count=%d\n", len(objs))
-
var toolCalls []api.ToolCall
for _, kv := range objs {
- n, nok := kv[name].(string)
- a, aok := kv[arguments].(map[string]any)
+ n, nok := kv[p.name].(string)
+ a, aok := kv[p.arguments].(map[string]any)
if nok && aok {
- fmt.Printf("found valid tool call: name=%s\n", n)
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
@@ -155,130 +73,170 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
}
}
- fmt.Printf("parsed tool calls: count=%d\n", len(toolCalls))
- return toolCalls, false, true
+ // Valid JSON, no tool calls found
+ if len(toolCalls) == 0 {
+ return nil, false
+ }
+
+ return toolCalls, false
}
-// prefix stripped string if any, prefix found, and if we should accumulate
+// checkPrefix processes a string to find and handle a prefix pattern.
+//
+// Returns:
+// - The processed string with prefix removed if found
+// - Whether the prefix was found at the start of the string
+// - Whether to continue parsing
func (p *Parser) checkPrefix(s string) (string, bool, bool) {
+ // Keep original for overlap checks
+ original := s
+ s = strings.TrimSpace(s)
+ if s == "" {
+ return "", false, true
+ }
+ // If no prefix defined, just return trimmed string
if p.prefix == "" {
return s, false, true
}
- original := s
- s = strings.TrimSpace(s)
- s, hasPrefix := strings.CutPrefix(s, p.prefix)
- if hasPrefix {
- // partial tool possibly - accumulate
- return s, true, true
- } else if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
- // p.state = PartialPrefix
- p.partialPrefix = true
- return original[0 : len(original)-overlap], false, false
- } else if idx := strings.Index(original, p.prefix); idx != -1 {
- // Found prefix in middle of string, keep only content before prefix
- // accounts for spaces in prefix or suffix to avoid breaking cache
- p.partialPrefix = true
- p.sb.Reset()
+ // Check for prefix at start of string
+ if processedStr, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
+ // Found prefix at start - accumulate for potential tool
+ return processedStr, true, true
+ }
+
+ // Check if prefix overlaps end of string
+ if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
+ p.prefixPartial = true
+ // Return everything except overlapping portion
+ p.sb.Reset()
+ p.sb.WriteString(original[len(original)-overlap:])
+ return original[0 : len(original)-overlap], false, false
+ }
+
+ // Check if prefix appears in middle of string
+ if idx := strings.Index(original, p.prefix); idx != -1 {
+ p.prefixPartial = true
+ // Save remainder starting at prefix for next pass
+ p.sb.Reset()
p.sb.WriteString(strings.TrimSpace(original[idx:]))
+ // Return everything before prefix
return original[:idx], false, false
}
- p.partialPrefix = false
+ // No prefix found
+ p.prefixPartial = false
return s, false, true
}
+// Add processes a string input to parse tool calls and content.
+// It handles prefix detection and JSON parsing to extract tool calls.
+//
+// Returns:
+// - tools: Any parsed tool calls
+// - content: Non-tool call content
+// - err: Error if parsing failed
func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) {
- slog.Debug("adding tool calls", "input", s)
-
p.sb.WriteString(s)
s = p.sb.String()
-
if len(s) == 0 {
return nil, "", nil
}
- s, prefixFound, cont := p.checkPrefix(s)
-
- if !cont {
+ // Check for prefix pattern in input
+ s, prefixFound, shouldContinue := p.checkPrefix(s)
+ if !shouldContinue {
if s != "" {
- // send only the content back, prefix exists
+ // Return content before prefix
return nil, s, nil
}
- // accumulate case
+ // Need more input to complete prefix
return nil, "", nil
}
- // circuit breaker
+ // Update prefix found state
if prefixFound {
p.prefixFound = true
}
- // for cases with a prefix in template
- if p.prefix != "" && !p.greedy && !p.prefixFound {
- // send tokens down
+ // Exit if prefix exists in template, greedy parsing is off, and prefix not found
+ if !p.greedyParse && !p.prefixFound {
p.sb.Reset()
return nil, "", errors.New("prefix not found")
}
- // we have a prefix or are in json mode
- tcs, partial, ok := p.parseJSONToolCalls(s)
- if partial {
- // accumulate case
+
+ toolCalls, isPartial := p.parseJSONToolCalls(s)
+ if isPartial {
+ // Need more input to complete JSON
return nil, "", nil
}
- p.greedy = false
- if !ok {
- // will not be a partial at this point
+ // Do not try greedy parsing if partial JSON not found
+ p.greedyParse = false
+
+ // Handle invalid tool call format
+ if len(toolCalls) == 0 {
p.sb.Reset()
- // send tokens
if p.prefix == "" {
p.Done = true
}
if p.prefixFound {
- // drop tokens instead - sb is reset, no tokens sent to user
+ // Drop tokens since prefix was found
return nil, "", nil
}
- return nil, "", errors.New("failed to parse tool calls")
+ return nil, s, nil
}
- for _, tc := range tcs {
+ for _, tc := range toolCalls {
tc.Function.Index = p.index
p.index++
}
+
+ // Mark as done if no prefix needed
if p.prefix == "" {
p.Done = true
}
+
p.sb.Reset()
- return tcs, "", nil
+ return toolCalls, "", nil
}
+// NewParser creates a new tool call parser from a template. It extracts the tool call format,
+// prefix, and field names from the template to use for parsing tool calls from model output.
+//
+// Returns an error if the template does not contain valid tool call formatting.
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
- parsedTemplate, err := template.Parse(templateToProcess.Root.String())
+ parsed, err := template.Parse(templateToProcess.Root.String())
if err != nil {
return nil, err
}
- if parsedTemplate == nil {
+ if parsed == nil {
return nil, errors.New("failed to parse template")
}
- toolCallTemplate, hasToolCalls := toolTemplate(parsedTemplate)
- if !hasToolCalls {
- return nil, errors.New("failed to find tool template")
+ tt, tc := toolTemplate(parsed)
+ if !tc {
+ return nil, errors.New("failed to find tool calls in template")
}
- if toolCallTemplate == nil {
+ if tt == nil {
return nil, errors.New("failed to find tool template")
}
- toolPrefix, _ := ToolPrefix(templateToProcess)
- toolPrefix = strings.TrimSpace(toolPrefix)
+ tp := toolPrefix(templateToProcess)
+ tp = strings.TrimSpace(tp)
+
+ name, arguments, err := extractToolArgs(tt)
+ if err != nil {
+ return nil, err
+ }
- fmt.Printf("creating new tool parser: prefix=%s\n", toolPrefix)
return &Parser{
- tmpl: toolCallTemplate,
- sb: &strings.Builder{},
- prefix: toolPrefix,
- greedy: true,
+ tmpl: tt,
+ sb: &strings.Builder{},
+ prefix: tp,
+ greedyParse: true,
+ name: name,
+ arguments: arguments,
}, nil
}
diff --git a/tools/tools_test.go b/tools/tools_test.go
index 597d0bdd1..bc436f838 100644
--- a/tools/tools_test.go
+++ b/tools/tools_test.go
@@ -6,11 +6,8 @@ import (
"fmt"
"os"
"path/filepath"
- "slices"
"strings"
"testing"
- gotmpl "text/template"
- "text/template/parse"
"github.com/google/go-cmp/cmp"
@@ -206,6 +203,27 @@ func TestParseToolCalls(t *testing.T) {
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
+ {
+ name: "qwen2.5 tool calls without prefix and valid tool call",
+ model: "qwen2.5-coder",
+ output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
+ expectedToolCall: []api.ToolCall{t1, t2},
+ expectedTokens: "",
+ },
+ {
+ name: "qwen2.5 tool calls without prefix and invalid tool call",
+ model: "qwen2.5-coder",
+ output: `[{"options": "foo"}]`,
+ expectedToolCall: []api.ToolCall{},
+ expectedTokens: `[{"options": "foo"}]`,
+ },
+ {
+ name: "qwen2.5 tool calls with prefix and invalid tool call",
+ model: "qwen2.5-coder",
+ output: ` [{"options": "foo"}] `,
+ expectedToolCall: []api.ToolCall{},
+ expectedTokens: ``,
+ },
{
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
model: "qwen3",
@@ -239,14 +257,14 @@ func TestParseToolCalls(t *testing.T) {
model: "qwen3",
output: `< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `,
expectedToolCall: []api.ToolCall{},
- expectedTokens: ` fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `,
+ expectedTokens: `< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `,
},
{
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
model: "qwen3",
output: ``,
expectedToolCall: []api.ToolCall{},
- expectedTokens: ` fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `,
+ expectedTokens: ``,
},
{
name: "qwen3 invalid tool call with malformed tool prefix",
@@ -332,6 +350,7 @@ func TestParseToolCalls(t *testing.T) {
toolCalls, content, err := tp.Add(s)
if err == nil {
if content != "" {
+ fmt.Printf("content: %q\n", content)
gotTokens.WriteString(content)
add = false
} else if len(toolCalls) > 0 {
@@ -363,24 +382,101 @@ func TestParseToolCalls(t *testing.T) {
}
}
-func toolTemplateHelper(t *testing.T, tmpl *template.Template) (*gotmpl.Template, bool) {
- // create a subtree from the node that ranges over .ToolCalls
-
- tmpl2 := tmpl.Subtree(func(n parse.Node) bool {
- if t, ok := n.(*parse.RangeNode); ok {
- return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
- }
-
- return false
- })
-
- if tmpl2.Root != nil {
- t.Log("tmpl2", tmpl2.Root.String())
+func TestParseJSONToolCalls(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ parser *Parser
+ wantToolCalls []api.ToolCall
+ wantPartial bool
+ wantValid bool
+ }{
+ {
+ name: "valid single tool call",
+ input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`,
+ parser: &Parser{name: "name", arguments: "arguments"},
+ wantToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "test_tool",
+ Arguments: map[string]any{
+ "arg1": "value1",
+ },
+ },
+ },
+ },
+ wantPartial: false,
+ wantValid: true,
+ },
+ {
+ name: "incomplete JSON",
+ input: `{"name": "test_tool", "arguments": {"arg1": `,
+ parser: &Parser{name: "name", arguments: "arguments"},
+ wantToolCalls: nil,
+ wantPartial: true,
+ wantValid: false,
+ },
+ {
+ name: "invalid JSON",
+ input: `not json at all`,
+ parser: &Parser{name: "name", arguments: "arguments"},
+ wantToolCalls: nil,
+ wantPartial: false,
+ wantValid: false,
+ },
+ {
+ name: "missing required fields",
+ input: `{"other": "field"}`,
+ parser: &Parser{name: "name", arguments: "arguments"},
+ wantToolCalls: nil,
+ wantPartial: false,
+ wantValid: false,
+ },
+ {
+ name: "multiple tool calls in array",
+ input: `[
+ {"name": "tool1", "arguments": {"arg1": 1}},
+ {"name": "tool2", "arguments": {"arg2": "value"}}
+ ]`,
+ parser: &Parser{name: "name", arguments: "arguments"},
+ wantToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "tool1",
+ Arguments: map[string]any{
+ "arg1": float64(1),
+ },
+ },
+ },
+ {
+ Function: api.ToolCallFunction{
+ Name: "tool2",
+ Arguments: map[string]any{
+ "arg2": "value",
+ },
+ },
+ },
+ },
+ wantPartial: false,
+ wantValid: true,
+ },
}
- if tmpl2 == nil {
- return nil, false
- }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotCalls, gotPartial := tt.parser.parseJSONToolCalls(tt.input)
- return tmpl2, true
+ if gotPartial != tt.wantPartial {
+ t.Errorf("parseJSONToolCalls() partial = %v, want %v", gotPartial, tt.wantPartial)
+ }
+
+ if len(gotCalls) != 0 != tt.wantValid {
+ t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantValid)
+ }
+
+ if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" {
+ t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff)
+ }
+ })
+ }
}
diff --git a/tools/utils.go b/tools/utils.go
index 6e37b9010..64f88658c 100644
--- a/tools/utils.go
+++ b/tools/utils.go
@@ -1,17 +1,27 @@
package tools
import (
+ "bytes"
+ "errors"
"log/slog"
"slices"
"strings"
gotmpl "text/template"
"text/template/parse"
+ jsonv2 "github.com/go-json-experiment/json"
+ "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
-// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls"
-func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
+// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
+// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
+// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
+//
+// Returns:
+// - string: The extracted text following the first ".ToolCalls" condition found
+// - bool: Whether a ".ToolCalls" condition was found in the template
+func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
if tmpl == nil || tmpl.Tree == nil {
slog.Debug("TextAfterToolCalls: template or tree is nil")
return "", false
@@ -29,7 +39,7 @@ func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
switch n := node.(type) {
case *parse.IfNode:
- if nodeContainsToolCalls(n) {
+ if isToolCallsNode(n) {
// Collect immediate TextNode(s) at start of IfNode's list
var sb strings.Builder
for _, innerNode := range n.List.Nodes {
@@ -76,8 +86,8 @@ func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
return result, found
}
-// Helper to detect if a node's condition includes ".ToolCalls"
-func nodeContainsToolCalls(n *parse.IfNode) bool {
+// isToolCallsNode detects if a node's condition includes ".ToolCalls"
+func isToolCallsNode(n *parse.IfNode) bool {
for _, cmd := range n.Pipe.Cmds {
for _, arg := range cmd.Args {
if field, ok := arg.(*parse.FieldNode); ok {
@@ -90,16 +100,17 @@ func nodeContainsToolCalls(n *parse.IfNode) bool {
return false
}
-// ToolPrefix returns the prefix for the tool call if it exists
// TODO(parthsareen): get full prefix from the template instead of just the first token
-func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
- tokenText, ok := extractToolCallsTemplate(tmpl)
+
+// toolPrefix returns the prefix for the tool call if it exists from a template
+func toolPrefix(tmpl *gotmpl.Template) string {
+ tokenText, ok := extractToolCallsFormat(tmpl)
if !ok {
- return "", false
+ return ""
}
tokenText = strings.TrimSpace(tokenText)
if tokenText == "" {
- return "", false
+ return ""
}
first := strings.Fields(tokenText)[0]
@@ -116,19 +127,23 @@ func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
}
if start != -1 && end != -1 {
// return the token including the [ or < and the ] or >
- return tokenText[start : end+1], true
+ return tokenText[start : end+1]
} else if start != -1 {
// get until the [ or < - in the case tag was not closed
- return tokenText[:start], true
+ return tokenText[:start]
} else if end != -1 {
// get after the ] or > - in the case tag was not opened
- return tokenText[end+1:], true
+ return tokenText[end+1:]
}
- return first, true
+ return first
}
+// toolTemplate creates a subtree from the node that ranges over .ToolCalls
+//
+// Returns:
+// - *gotmpl.Template: The subtree containing the .ToolCalls range
+// - bool: Whether a .ToolCalls range was found in the template
func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
- // create a subtree from the node that ranges over .ToolCalls
tmpl := t.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
@@ -144,6 +159,10 @@ func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
return tmpl, true
}
+// suffixOverlap returns the length of the longest suffix overlap between two strings
+//
+// Returns:
+// - int: The length of the longest suffix overlap
func suffixOverlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
@@ -153,3 +172,86 @@ func suffixOverlap(s, delim string) int {
}
return 0
}
+
+// extractToolArgs executes a template with a known tool call format to extract the name and arguments
+//
+// Returns:
+// - string: The name of the tool call
+// - string: The arguments of the tool call
+// - error: Error if parsing failed
+func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
+ var b bytes.Buffer
+ if err := tmpl.Execute(&b, map[string][]api.ToolCall{
+ "ToolCalls": {
+ {
+ Function: api.ToolCallFunction{
+ Name: "@@name@@",
+ Arguments: api.ToolCallFunctionArguments{
+ "@@argument@@": 1,
+ },
+ },
+ },
+ },
+ }); err != nil {
+ return "", "", err
+ }
+
+ var obj any
+ err = jsonv2.Unmarshal(b.Bytes(), &obj)
+ if err != nil {
+ return "", "", err
+ }
+
+ var objs []map[string]any
+ switch v := obj.(type) {
+ case map[string]any:
+ objs = []map[string]any{v}
+ case []map[string]any:
+ objs = v
+ case []any:
+ objs = collect(v)
+ }
+ if len(objs) == 0 {
+ return "", "", errors.New("no template objects found")
+ }
+
+ // find the keys that correspond to the name and arguments fields
+ for k, v := range objs[0] {
+ switch v.(type) {
+ case string:
+ name = k
+ case map[string]any:
+ arguments = k
+ }
+ }
+
+ if name == "" || arguments == "" {
+ slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
+ return "", "", errors.New("missing required fields in tool call template")
+ }
+
+ return name, arguments, nil
+}
+
+// collect recursively traverses an object to collect all nested maps
+//
+// Returns:
+// - []map[string]any: A slice of all nested maps found in the object
+func collect(obj any) []map[string]any {
+ var all []map[string]any
+ switch o := obj.(type) {
+ case map[string]any:
+ all = append(all, o)
+ for _, v := range o {
+ all = append(all, collect(v)...)
+ }
+ case []any:
+ for _, v := range o {
+ all = append(all, collect(v)...)
+ }
+ default:
+ return nil
+ }
+
+ return all
+}
diff --git a/tools/utils_test.go b/tools/utils_test.go
index 4c37ecd40..c082fde02 100644
--- a/tools/utils_test.go
+++ b/tools/utils_test.go
@@ -3,74 +3,133 @@ package tools
import (
"testing"
gotmpl "text/template"
+
+ "github.com/ollama/ollama/template"
)
+func TestExtractToolCallsFormat(t *testing.T) {
+ cases := []struct {
+ name string
+ template string
+ want string
+ found bool
+ }{
+ {
+ name: "nil template",
+ template: "",
+ want: "",
+ found: false,
+ },
+ {
+ name: "basic tool call with text",
+ template: "{{if .ToolCalls}}Hello world{{end}}",
+ want: "Hello world",
+ found: true,
+ },
+ {
+ name: "tool call with json format",
+ template: "{{if .ToolCalls}}```json\n{{end}}",
+ want: "```json\n",
+ found: true,
+ },
+ {
+ name: "tool call in range",
+ template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
+ want: "",
+ found: false,
+ },
+ {
+ name: "tool call with multiple text nodes",
+ template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
+ want: "First text",
+ found: true,
+ },
+ {
+ name: "nested if without tool calls",
+ template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
+ want: "",
+ found: false,
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ tmpl, err := gotmpl.New("test").Parse(tc.template)
+ if err != nil && tc.template != "" {
+ t.Fatalf("failed to parse template: %v", err)
+ }
+
+ got, found := extractToolCallsFormat(tmpl)
+ if got != tc.want {
+ t.Errorf("got text %q, want %q", got, tc.want)
+ }
+ if found != tc.found {
+ t.Errorf("got found %v, want %v", found, tc.found)
+ }
+ })
+ }
+}
+
func TestToolPrefix(t *testing.T) {
cases := []struct {
name string
template string
want string
- ok bool
}{
{
name: "basic tool call with action prefix",
template: "{{if .ToolCalls}}Action: ```json{{end}}",
want: "Action:",
- ok: true,
},
{
name: "incomplete functools bracket",
template: "{{if .ToolCalls}}functools[{{end}}",
want: "functools",
- ok: true,
},
{
name: "tool call with angle brackets",
template: "{{if .ToolCalls}}Hello, world! {{end}}",
want: "",
- ok: true,
},
{
name: "multiple tool call formats",
template: "{{if .ToolCalls}}[tool_call] {{end}}",
want: "[tool_call]",
- ok: true,
},
{
name: "single angle bracket tool call",
template: "{{if .ToolCalls}}{{end}}",
want: "",
- ok: true,
},
{
name: "incomplete angle bracket after tool call",
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
want: "[tool_call]",
- ok: true,
},
{
name: "angle bracket prefix with tool call",
template: "{{if .ToolCalls}}> {{end}}",
want: "",
- ok: true,
},
{
name: "uppercase tool call with incomplete bracket",
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
want: "[TOOL_CALL]",
- ok: true,
},
{
name: "uppercase tool call with adjacent bracket",
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
want: "[TOOL_CALL]",
- ok: true,
},
{
name: "tool call with pipe delimiters",
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
want: "<|tool_call|>",
- ok: true,
+ },
+ {
+ name: "tool with no prefix",
+ template: "{{if .ToolCalls}}{{end}}",
+ want: "",
},
}
@@ -80,18 +139,135 @@ func TestToolPrefix(t *testing.T) {
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
- got, ok := ToolPrefix(tmpl)
+ got := toolPrefix(tmpl)
if got != tt.want {
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
}
- if ok != tt.ok {
- t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok)
- }
})
}
}
-func TestTextAfterToolCalls(t *testing.T) {
+func TestToolTemplate(t *testing.T) {
+ cases := []struct {
+ name string
+ template string
+ want bool
+ }{
+ {
+ name: "basic tool call range",
+ template: "{{range .ToolCalls}}test{{end}}",
+ want: true,
+ },
+ {
+ name: "no tool calls",
+ template: "{{range .Other}}test{{end}}",
+ want: false,
+ },
+ {
+ name: "nested tool calls",
+ template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
+ want: true,
+ },
+ {
+ name: "empty template",
+ template: "",
+ want: false,
+ },
+ {
+ name: "tool calls in if statement",
+ template: "{{if .ToolCalls}}test{{end}}",
+ want: false,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpl, err := gotmpl.New("test").Parse(tt.template)
+ if err != nil {
+ t.Fatalf("failed to parse template: %v", err)
+ }
+
+ parsed, err := template.Parse(tmpl.Root.String())
+ if err != nil {
+ t.Fatalf("failed to parse template: %v", err)
+ }
+
+ _, got := toolTemplate(parsed)
+ if got != tt.want {
+ t.Errorf("toolTemplate() = %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestSuffixOverlap(t *testing.T) {
+ cases := []struct {
+ name string
+ s string
+ d string
+ want int
+ }{
+ {
+ name: "no overlap",
+ s: "hello world",
+ d: "",
+ want: 0,
+ },
+ {
+ name: "full overlap",
+ s: "",
+ d: "",
+ want: 11,
+ },
+ {
+ name: "partial overlap",
+ s: "text ",
+ d: "",
+ want: 11,
+ },
+ {
+ name: "delimiter longer than string",
+ s: "",
+ d: "",
+ want: 0,
+ },
+ {
+ name: "empty string",
+ s: "",
+ d: "",
+ want: 0,
+ },
+ {
+ name: "empty delimiter",
+ s: "",
+ d: "",
+ want: 0,
+ },
+ {
+ name: "single char overlap",
+ s: "test<",
+ d: "",
+ want: 1,
+ },
+ {
+ name: "partial tool call",
+ s: "hello ",
+ want: 6,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ got := suffixOverlap(tt.s, tt.d)
+ if got != tt.want {
+ t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestExtractToolArgs(t *testing.T) {
cases := []struct {
name string
template string
@@ -173,7 +349,7 @@ func TestTextAfterToolCalls(t *testing.T) {
t.Fatalf("failed to parse template: %v", err)
}
- got, ok := extractToolCallsTemplate(tmpl)
+ got, ok := extractToolCallsFormat(tmpl)
if got != tt.want {
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
}
@@ -183,3 +359,106 @@ func TestTextAfterToolCalls(t *testing.T) {
})
}
}
+
+func TestCollect(t *testing.T) {
+ cases := []struct {
+ name string
+ obj any
+ want []map[string]any
+ }{
+ {
+ name: "simple map",
+ obj: map[string]any{
+ "key": "value",
+ },
+ want: []map[string]any{
+ {"key": "value"},
+ },
+ },
+ {
+ name: "nested map",
+ obj: map[string]any{
+ "outer": map[string]any{
+ "inner": "value",
+ },
+ },
+ want: []map[string]any{
+ {"outer": map[string]any{"inner": "value"}},
+ {"inner": "value"},
+ },
+ },
+ {
+ name: "array of maps",
+ obj: []any{
+ map[string]any{"key1": "val1"},
+ map[string]any{"key2": "val2"},
+ },
+ want: []map[string]any{
+ {"key1": "val1"},
+ {"key2": "val2"},
+ },
+ },
+ {
+ name: "deeply nested",
+ obj: map[string]any{
+ "l1": map[string]any{
+ "l2": map[string]any{
+ "l3": "value",
+ },
+ },
+ },
+ want: []map[string]any{
+ {"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
+ {"l2": map[string]any{"l3": "value"}},
+ {"l3": "value"},
+ },
+ },
+ {
+ name: "non-map value",
+ obj: "string",
+ want: nil,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ got := collect(tt.obj)
+ if len(got) != len(tt.want) {
+ t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
+ return
+ }
+
+ // Compare each map in the result
+ for i := range tt.want {
+ if !mapsEqual(got[i], tt.want[i]) {
+ t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
+ }
+ }
+ })
+ }
+}
+
+// mapsEqual compares two maps for deep equality
+func mapsEqual(m1, m2 map[string]any) bool {
+ if len(m1) != len(m2) {
+ return false
+ }
+ for k, v1 := range m1 {
+ v2, ok := m2[k]
+ if !ok {
+ return false
+ }
+ switch val1 := v1.(type) {
+ case map[string]any:
+ val2, ok := v2.(map[string]any)
+ if !ok || !mapsEqual(val1, val2) {
+ return false
+ }
+ default:
+ if v1 != v2 {
+ return false
+ }
+ }
+ }
+ return true
+}