diff --git a/server/routes.go b/server/routes.go index 64823bd32..c1868ff89 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1529,9 +1529,8 @@ func (s *Server) ChatHandler(c *gin.Context) { if err == nil { if len(content) > 0 { res.Message.Content = content - fmt.Println("sending content in response", content) + slog.Debug("tools: setting content to", "content", content) } else if len(toolCalls) > 0 { - fmt.Println("sending tool calls in response", toolCalls) res.Message.ToolCalls = toolCalls res.Message.Content = "" } else { diff --git a/tools/tools.go b/tools/tools.go index 1105848e9..bf72cf212 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -1,9 +1,7 @@ package tools import ( - "bytes" "errors" - "fmt" "io" "log/slog" "strings" @@ -16,136 +14,56 @@ import ( "github.com/ollama/ollama/template" ) -// TODO: simplify if possible type Parser struct { - greedy bool + greedyParse bool prefixFound bool - partialPrefix bool + prefixPartial bool tmpl *gotmpl.Template sb *strings.Builder prefix string index int + name string + arguments string Done bool } // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. -// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found -func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { - fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s) - - var b bytes.Buffer - if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{ - "ToolCalls": { - { - Function: api.ToolCallFunction{ - Name: "@@name@@", - Arguments: api.ToolCallFunctionArguments{ - "@@argument@@": 1, - }, - }, - }, - }, - }); err != nil { - fmt.Printf("failed to execute template: error=%v\n", err) - return nil, false, false - } - - // this can be either a map or an array - var temp any - err := jsonv2.Unmarshal(b.Bytes(), &temp) - if err != nil { - fmt.Printf("failed to unmarshal template: error=%v\n", err) - return nil, false, false - } - - var collect func(any) []map[string]any - collect = func(obj any) (all []map[string]any) { - switch o := obj.(type) { - case map[string]any: - all = append(all, o) - for _, v := range o { - all = append(all, collect(v)...) - } - case []any: - for _, v := range o { - all = append(all, collect(v)...) - } - default: - // TODO: err or fallback - fmt.Printf("collect encountered unknown type: type=%T\n", obj) - return nil - } - - return all - } - - var templateObjects []map[string]any - switch t := temp.(type) { - case map[string]any: - templateObjects = []map[string]any{t} - case []map[string]any: - templateObjects = t - // ! fallback? - case []any: - templateObjects = collect(t) - } - if len(templateObjects) == 0 { - fmt.Println("no template objects found") - return nil, false, false - } - - // find the keys that correspond to the name and arguments fields - var name, arguments string - for k, v := range templateObjects[0] { - switch v.(type) { - case string: - name = k - fmt.Printf("found name field: key=%s\n", k) - case map[string]any: - arguments = k - fmt.Printf("found arguments field: key=%s\n", k) - } - } - - if name == "" || arguments == "" { - fmt.Printf("missing required fields: name_found=%v arguments_found=%v\n", name != "", arguments != "") - return nil, false, false - } - - // TODO: there is probably some underlying repeat work here to avoid - // This incrementally decodes the JSON string and returns the first parsedobject +// It first tries to incrementally decode the JSON to handle partial inputs. +// Returns: +// - []api.ToolCall: The parsed tool calls if successful +// - bool: True if JSON is incomplete and needs more input +func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) { + // First try incremental decoding to handle partial JSON dec := jsontext.NewDecoder(strings.NewReader(s)) if got, err := dec.ReadValue(); err == nil { s = got.String() - fmt.Printf("decoded JSON value: value=%s\n", s) } - var responseObjects any - err = jsonv2.Unmarshal([]byte(s), &responseObjects) + // Attempt full unmarshal of the JSON + var resp any + err := jsonv2.Unmarshal([]byte(s), &resp) if err != nil { + // Handle incomplete JSON cases if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" { - fmt.Println("incomplete JSON detected") - return nil, true, false - } else { - fmt.Printf("failed to unmarshal response: error=%v\n", err) - return nil, false, false + slog.Debug("incomplete JSON detected", "input", s) + return nil, true } + slog.Debug("failed to unmarshal response", "error", err) + return nil, false } + // Collect all nested objects that could contain tool calls var objs []map[string]any - objs = append(objs, collect(responseObjects)...) + objs = append(objs, collect(resp)...) if len(objs) == 0 { - return nil, false, false + return nil, false } - fmt.Printf("collected objects: count=%d\n", len(objs)) - var toolCalls []api.ToolCall for _, kv := range objs { - n, nok := kv[name].(string) - a, aok := kv[arguments].(map[string]any) + n, nok := kv[p.name].(string) + a, aok := kv[p.arguments].(map[string]any) if nok && aok { - fmt.Printf("found valid tool call: name=%s\n", n) toolCalls = append(toolCalls, api.ToolCall{ Function: api.ToolCallFunction{ Name: n, @@ -155,130 +73,170 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { } } - fmt.Printf("parsed tool calls: count=%d\n", len(toolCalls)) - return toolCalls, false, true + // Valid JSON, no tool calls found + if len(toolCalls) == 0 { + return nil, false + } + + return toolCalls, false } -// prefix stripped string if any, prefix found, and if we should accumulate +// checkPrefix processes a string to find and handle a prefix pattern. +// +// Returns: +// - The processed string with prefix removed if found +// - Whether the prefix was found at the start of the string +// - Whether to continue parsing func (p *Parser) checkPrefix(s string) (string, bool, bool) { + // Keep original for overlap checks + original := s + s = strings.TrimSpace(s) + if s == "" { + return "", false, true + } + // If no prefix defined, just return trimmed string if p.prefix == "" { return s, false, true } - original := s - s = strings.TrimSpace(s) - s, hasPrefix := strings.CutPrefix(s, p.prefix) - if hasPrefix { - // partial tool possibly - accumulate - return s, true, true - } else if overlap := suffixOverlap(original, p.prefix); overlap > 0 { - // p.state = PartialPrefix - p.partialPrefix = true - return original[0 : len(original)-overlap], false, false - } else if idx := strings.Index(original, p.prefix); idx != -1 { - // Found prefix in middle of string, keep only content before prefix - // accounts for spaces in prefix or suffix to avoid breaking cache - p.partialPrefix = true - p.sb.Reset() + // Check for prefix at start of string + if processedStr, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { + // Found prefix at start - accumulate for potential tool + return processedStr, true, true + } + + // Check if prefix overlaps end of string + if overlap := suffixOverlap(original, p.prefix); overlap > 0 { + p.prefixPartial = true + // Return everything except overlapping portion + p.sb.Reset() + p.sb.WriteString(original[len(original)-overlap:]) + return original[0 : len(original)-overlap], false, false + } + + // Check if prefix appears in middle of string + if idx := strings.Index(original, p.prefix); idx != -1 { + p.prefixPartial = true + // Save remainder starting at prefix for next pass + p.sb.Reset() p.sb.WriteString(strings.TrimSpace(original[idx:])) + // Return everything before prefix return original[:idx], false, false } - p.partialPrefix = false + // No prefix found + p.prefixPartial = false return s, false, true } +// Add processes a string input to parse tool calls and content. +// It handles prefix detection and JSON parsing to extract tool calls. +// +// Returns: +// - tools: Any parsed tool calls +// - content: Non-tool call content +// - err: Error if parsing failed func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) { - slog.Debug("adding tool calls", "input", s) - p.sb.WriteString(s) s = p.sb.String() - if len(s) == 0 { return nil, "", nil } - s, prefixFound, cont := p.checkPrefix(s) - - if !cont { + // Check for prefix pattern in input + s, prefixFound, shouldContinue := p.checkPrefix(s) + if !shouldContinue { if s != "" { - // send only the content back, prefix exists + // Return content before prefix return nil, s, nil } - // accumulate case + // Need more input to complete prefix return nil, "", nil } - // circuit breaker + // Update prefix found state if prefixFound { p.prefixFound = true } - // for cases with a prefix in template - if p.prefix != "" && !p.greedy && !p.prefixFound { - // send tokens down + // Exit if prefix exists in template, greedy parsing is off, and prefix not found + if !p.greedyParse && !p.prefixFound { p.sb.Reset() return nil, "", errors.New("prefix not found") } - // we have a prefix or are in json mode - tcs, partial, ok := p.parseJSONToolCalls(s) - if partial { - // accumulate case + + toolCalls, isPartial := p.parseJSONToolCalls(s) + if isPartial { + // Need more input to complete JSON return nil, "", nil } - p.greedy = false - if !ok { - // will not be a partial at this point + // Do not try greedy parsing if partial JSON not found + p.greedyParse = false + + // Handle invalid tool call format + if len(toolCalls) == 0 { p.sb.Reset() - // send tokens if p.prefix == "" { p.Done = true } if p.prefixFound { - // drop tokens instead - sb is reset, no tokens sent to user + // Drop tokens since prefix was found return nil, "", nil } - return nil, "", errors.New("failed to parse tool calls") + return nil, s, nil } - for _, tc := range tcs { + for _, tc := range toolCalls { tc.Function.Index = p.index p.index++ } + + // Mark as done if no prefix needed if p.prefix == "" { p.Done = true } + p.sb.Reset() - return tcs, "", nil + return toolCalls, "", nil } +// NewParser creates a new tool call parser from a template. It extracts the tool call format, +// prefix, and field names from the template to use for parsing tool calls from model output. +// +// Returns an error if the template does not contain valid tool call formatting. func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { - parsedTemplate, err := template.Parse(templateToProcess.Root.String()) + parsed, err := template.Parse(templateToProcess.Root.String()) if err != nil { return nil, err } - if parsedTemplate == nil { + if parsed == nil { return nil, errors.New("failed to parse template") } - toolCallTemplate, hasToolCalls := toolTemplate(parsedTemplate) - if !hasToolCalls { - return nil, errors.New("failed to find tool template") + tt, tc := toolTemplate(parsed) + if !tc { + return nil, errors.New("failed to find tool calls in template") } - if toolCallTemplate == nil { + if tt == nil { return nil, errors.New("failed to find tool template") } - toolPrefix, _ := ToolPrefix(templateToProcess) - toolPrefix = strings.TrimSpace(toolPrefix) + tp := toolPrefix(templateToProcess) + tp = strings.TrimSpace(tp) + + name, arguments, err := extractToolArgs(tt) + if err != nil { + return nil, err + } - fmt.Printf("creating new tool parser: prefix=%s\n", toolPrefix) return &Parser{ - tmpl: toolCallTemplate, - sb: &strings.Builder{}, - prefix: toolPrefix, - greedy: true, + tmpl: tt, + sb: &strings.Builder{}, + prefix: tp, + greedyParse: true, + name: name, + arguments: arguments, }, nil } diff --git a/tools/tools_test.go b/tools/tools_test.go index 597d0bdd1..bc436f838 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -6,11 +6,8 @@ import ( "fmt" "os" "path/filepath" - "slices" "strings" "testing" - gotmpl "text/template" - "text/template/parse" "github.com/google/go-cmp/cmp" @@ -206,6 +203,27 @@ func TestParseToolCalls(t *testing.T) { expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "", }, + { + name: "qwen2.5 tool calls without prefix and valid tool call", + model: "qwen2.5-coder", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + }, + { + name: "qwen2.5 tool calls without prefix and invalid tool call", + model: "qwen2.5-coder", + output: `[{"options": "foo"}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `[{"options": "foo"}]`, + }, + { + name: "qwen2.5 tool calls with prefix and invalid tool call", + model: "qwen2.5-coder", + output: ` [{"options": "foo"}] `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: ``, + }, { name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)", model: "qwen3", @@ -239,14 +257,14 @@ func TestParseToolCalls(t *testing.T) { model: "qwen3", output: `< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, expectedToolCall: []api.ToolCall{}, - expectedTokens: ` fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedTokens: `< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, }, { name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", model: "qwen3", output: ``, expectedToolCall: []api.ToolCall{}, - expectedTokens: ` fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedTokens: ``, }, { name: "qwen3 invalid tool call with malformed tool prefix", @@ -332,6 +350,7 @@ func TestParseToolCalls(t *testing.T) { toolCalls, content, err := tp.Add(s) if err == nil { if content != "" { + fmt.Printf("content: %q\n", content) gotTokens.WriteString(content) add = false } else if len(toolCalls) > 0 { @@ -363,24 +382,101 @@ func TestParseToolCalls(t *testing.T) { } } -func toolTemplateHelper(t *testing.T, tmpl *template.Template) (*gotmpl.Template, bool) { - // create a subtree from the node that ranges over .ToolCalls - - tmpl2 := tmpl.Subtree(func(n parse.Node) bool { - if t, ok := n.(*parse.RangeNode); ok { - return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") - } - - return false - }) - - if tmpl2.Root != nil { - t.Log("tmpl2", tmpl2.Root.String()) +func TestParseJSONToolCalls(t *testing.T) { + tests := []struct { + name string + input string + parser *Parser + wantToolCalls []api.ToolCall + wantPartial bool + wantValid bool + }{ + { + name: "valid single tool call", + input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`, + parser: &Parser{name: "name", arguments: "arguments"}, + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test_tool", + Arguments: map[string]any{ + "arg1": "value1", + }, + }, + }, + }, + wantPartial: false, + wantValid: true, + }, + { + name: "incomplete JSON", + input: `{"name": "test_tool", "arguments": {"arg1": `, + parser: &Parser{name: "name", arguments: "arguments"}, + wantToolCalls: nil, + wantPartial: true, + wantValid: false, + }, + { + name: "invalid JSON", + input: `not json at all`, + parser: &Parser{name: "name", arguments: "arguments"}, + wantToolCalls: nil, + wantPartial: false, + wantValid: false, + }, + { + name: "missing required fields", + input: `{"other": "field"}`, + parser: &Parser{name: "name", arguments: "arguments"}, + wantToolCalls: nil, + wantPartial: false, + wantValid: false, + }, + { + name: "multiple tool calls in array", + input: `[ + {"name": "tool1", "arguments": {"arg1": 1}}, + {"name": "tool2", "arguments": {"arg2": "value"}} + ]`, + parser: &Parser{name: "name", arguments: "arguments"}, + wantToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]any{ + "arg1": float64(1), + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]any{ + "arg2": "value", + }, + }, + }, + }, + wantPartial: false, + wantValid: true, + }, } - if tmpl2 == nil { - return nil, false - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCalls, gotPartial := tt.parser.parseJSONToolCalls(tt.input) - return tmpl2, true + if gotPartial != tt.wantPartial { + t.Errorf("parseJSONToolCalls() partial = %v, want %v", gotPartial, tt.wantPartial) + } + + if len(gotCalls) != 0 != tt.wantValid { + t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantValid) + } + + if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" { + t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff) + } + }) + } } diff --git a/tools/utils.go b/tools/utils.go index 6e37b9010..64f88658c 100644 --- a/tools/utils.go +++ b/tools/utils.go @@ -1,17 +1,27 @@ package tools import ( + "bytes" + "errors" "log/slog" "slices" "strings" gotmpl "text/template" "text/template/parse" + jsonv2 "github.com/go-json-experiment/json" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/template" ) -// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls" -func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) { +// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition. +// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any +// immediate text nodes that follow. This is used to identify tool call prefixes and formatting. +// +// Returns: +// - string: The extracted text following the first ".ToolCalls" condition found +// - bool: Whether a ".ToolCalls" condition was found in the template +func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) { if tmpl == nil || tmpl.Tree == nil { slog.Debug("TextAfterToolCalls: template or tree is nil") return "", false @@ -29,7 +39,7 @@ func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) { switch n := node.(type) { case *parse.IfNode: - if nodeContainsToolCalls(n) { + if isToolCallsNode(n) { // Collect immediate TextNode(s) at start of IfNode's list var sb strings.Builder for _, innerNode := range n.List.Nodes { @@ -76,8 +86,8 @@ func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) { return result, found } -// Helper to detect if a node's condition includes ".ToolCalls" -func nodeContainsToolCalls(n *parse.IfNode) bool { +// isToolCallsNode detects if a node's condition includes ".ToolCalls" +func isToolCallsNode(n *parse.IfNode) bool { for _, cmd := range n.Pipe.Cmds { for _, arg := range cmd.Args { if field, ok := arg.(*parse.FieldNode); ok { @@ -90,16 +100,17 @@ func nodeContainsToolCalls(n *parse.IfNode) bool { return false } -// ToolPrefix returns the prefix for the tool call if it exists // TODO(parthsareen): get full prefix from the template instead of just the first token -func ToolPrefix(tmpl *gotmpl.Template) (string, bool) { - tokenText, ok := extractToolCallsTemplate(tmpl) + +// toolPrefix returns the prefix for the tool call if it exists from a template +func toolPrefix(tmpl *gotmpl.Template) string { + tokenText, ok := extractToolCallsFormat(tmpl) if !ok { - return "", false + return "" } tokenText = strings.TrimSpace(tokenText) if tokenText == "" { - return "", false + return "" } first := strings.Fields(tokenText)[0] @@ -116,19 +127,23 @@ func ToolPrefix(tmpl *gotmpl.Template) (string, bool) { } if start != -1 && end != -1 { // return the token including the [ or < and the ] or > - return tokenText[start : end+1], true + return tokenText[start : end+1] } else if start != -1 { // get until the [ or < - in the case tag was not closed - return tokenText[:start], true + return tokenText[:start] } else if end != -1 { // get after the ] or > - in the case tag was not opened - return tokenText[end+1:], true + return tokenText[end+1:] } - return first, true + return first } +// toolTemplate creates a subtree from the node that ranges over .ToolCalls +// +// Returns: +// - *gotmpl.Template: The subtree containing the .ToolCalls range +// - bool: Whether a .ToolCalls range was found in the template func toolTemplate(t *template.Template) (*gotmpl.Template, bool) { - // create a subtree from the node that ranges over .ToolCalls tmpl := t.Subtree(func(n parse.Node) bool { if t, ok := n.(*parse.RangeNode); ok { return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") @@ -144,6 +159,10 @@ func toolTemplate(t *template.Template) (*gotmpl.Template, bool) { return tmpl, true } +// suffixOverlap returns the length of the longest suffix overlap between two strings +// +// Returns: +// - int: The length of the longest suffix overlap func suffixOverlap(s, delim string) int { max := min(len(delim), len(s)) for i := max; i > 0; i-- { @@ -153,3 +172,86 @@ func suffixOverlap(s, delim string) int { } return 0 } + +// extractToolArgs executes a template with a known tool call format to extract the name and arguments +// +// Returns: +// - string: The name of the tool call +// - string: The arguments of the tool call +// - error: Error if parsing failed +func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) { + var b bytes.Buffer + if err := tmpl.Execute(&b, map[string][]api.ToolCall{ + "ToolCalls": { + { + Function: api.ToolCallFunction{ + Name: "@@name@@", + Arguments: api.ToolCallFunctionArguments{ + "@@argument@@": 1, + }, + }, + }, + }, + }); err != nil { + return "", "", err + } + + var obj any + err = jsonv2.Unmarshal(b.Bytes(), &obj) + if err != nil { + return "", "", err + } + + var objs []map[string]any + switch v := obj.(type) { + case map[string]any: + objs = []map[string]any{v} + case []map[string]any: + objs = v + case []any: + objs = collect(v) + } + if len(objs) == 0 { + return "", "", errors.New("no template objects found") + } + + // find the keys that correspond to the name and arguments fields + for k, v := range objs[0] { + switch v.(type) { + case string: + name = k + case map[string]any: + arguments = k + } + } + + if name == "" || arguments == "" { + slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments) + return "", "", errors.New("missing required fields in tool call template") + } + + return name, arguments, nil +} + +// collect recursively traverses an object to collect all nested maps +// +// Returns: +// - []map[string]any: A slice of all nested maps found in the object +func collect(obj any) []map[string]any { + var all []map[string]any + switch o := obj.(type) { + case map[string]any: + all = append(all, o) + for _, v := range o { + all = append(all, collect(v)...) + } + case []any: + for _, v := range o { + all = append(all, collect(v)...) + } + default: + return nil + } + + return all +} diff --git a/tools/utils_test.go b/tools/utils_test.go index 4c37ecd40..c082fde02 100644 --- a/tools/utils_test.go +++ b/tools/utils_test.go @@ -3,74 +3,133 @@ package tools import ( "testing" gotmpl "text/template" + + "github.com/ollama/ollama/template" ) +func TestExtractToolCallsFormat(t *testing.T) { + cases := []struct { + name string + template string + want string + found bool + }{ + { + name: "nil template", + template: "", + want: "", + found: false, + }, + { + name: "basic tool call with text", + template: "{{if .ToolCalls}}Hello world{{end}}", + want: "Hello world", + found: true, + }, + { + name: "tool call with json format", + template: "{{if .ToolCalls}}```json\n{{end}}", + want: "```json\n", + found: true, + }, + { + name: "tool call in range", + template: "{{range .ToolCalls}}tool: {{.}}{{end}}", + want: "", + found: false, + }, + { + name: "tool call with multiple text nodes", + template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}", + want: "First text", + found: true, + }, + { + name: "nested if without tool calls", + template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}", + want: "", + found: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tc.template) + if err != nil && tc.template != "" { + t.Fatalf("failed to parse template: %v", err) + } + + got, found := extractToolCallsFormat(tmpl) + if got != tc.want { + t.Errorf("got text %q, want %q", got, tc.want) + } + if found != tc.found { + t.Errorf("got found %v, want %v", found, tc.found) + } + }) + } +} + func TestToolPrefix(t *testing.T) { cases := []struct { name string template string want string - ok bool }{ { name: "basic tool call with action prefix", template: "{{if .ToolCalls}}Action: ```json{{end}}", want: "Action:", - ok: true, }, { name: "incomplete functools bracket", template: "{{if .ToolCalls}}functools[{{end}}", want: "functools", - ok: true, }, { name: "tool call with angle brackets", template: "{{if .ToolCalls}}Hello, world! {{end}}", want: "", - ok: true, }, { name: "multiple tool call formats", template: "{{if .ToolCalls}}[tool_call] {{end}}", want: "[tool_call]", - ok: true, }, { name: "single angle bracket tool call", template: "{{if .ToolCalls}}{{end}}", want: "", - ok: true, }, { name: "incomplete angle bracket after tool call", template: "{{if .ToolCalls}}[tool_call] <{{end}}", want: "[tool_call]", - ok: true, }, { name: "angle bracket prefix with tool call", template: "{{if .ToolCalls}}> {{end}}", want: "", - ok: true, }, { name: "uppercase tool call with incomplete bracket", template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", want: "[TOOL_CALL]", - ok: true, }, { name: "uppercase tool call with adjacent bracket", template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", want: "[TOOL_CALL]", - ok: true, }, { name: "tool call with pipe delimiters", template: "{{if .ToolCalls}}<|tool_call|>{{end}}", want: "<|tool_call|>", - ok: true, + }, + { + name: "tool with no prefix", + template: "{{if .ToolCalls}}{{end}}", + want: "", }, } @@ -80,18 +139,135 @@ func TestToolPrefix(t *testing.T) { if err != nil { t.Fatalf("failed to parse template: %v", err) } - got, ok := ToolPrefix(tmpl) + got := toolPrefix(tmpl) if got != tt.want { t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) } - if ok != tt.ok { - t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok) - } }) } } -func TestTextAfterToolCalls(t *testing.T) { +func TestToolTemplate(t *testing.T) { + cases := []struct { + name string + template string + want bool + }{ + { + name: "basic tool call range", + template: "{{range .ToolCalls}}test{{end}}", + want: true, + }, + { + name: "no tool calls", + template: "{{range .Other}}test{{end}}", + want: false, + }, + { + name: "nested tool calls", + template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}", + want: true, + }, + { + name: "empty template", + template: "", + want: false, + }, + { + name: "tool calls in if statement", + template: "{{if .ToolCalls}}test{{end}}", + want: false, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + parsed, err := template.Parse(tmpl.Root.String()) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + _, got := toolTemplate(parsed) + if got != tt.want { + t.Errorf("toolTemplate() = %v; want %v", got, tt.want) + } + }) + } +} + +func TestSuffixOverlap(t *testing.T) { + cases := []struct { + name string + s string + d string + want int + }{ + { + name: "no overlap", + s: "hello world", + d: "", + want: 0, + }, + { + name: "full overlap", + s: "", + d: "", + want: 11, + }, + { + name: "partial overlap", + s: "text ", + d: "", + want: 11, + }, + { + name: "delimiter longer than string", + s: "", + d: "", + want: 0, + }, + { + name: "empty string", + s: "", + d: "", + want: 0, + }, + { + name: "empty delimiter", + s: "", + d: "", + want: 0, + }, + { + name: "single char overlap", + s: "test<", + d: "", + want: 1, + }, + { + name: "partial tool call", + s: "hello ", + want: 6, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := suffixOverlap(tt.s, tt.d) + if got != tt.want { + t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want) + } + }) + } +} + +func TestExtractToolArgs(t *testing.T) { cases := []struct { name string template string @@ -173,7 +349,7 @@ func TestTextAfterToolCalls(t *testing.T) { t.Fatalf("failed to parse template: %v", err) } - got, ok := extractToolCallsTemplate(tmpl) + got, ok := extractToolCallsFormat(tmpl) if got != tt.want { t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) } @@ -183,3 +359,106 @@ func TestTextAfterToolCalls(t *testing.T) { }) } } + +func TestCollect(t *testing.T) { + cases := []struct { + name string + obj any + want []map[string]any + }{ + { + name: "simple map", + obj: map[string]any{ + "key": "value", + }, + want: []map[string]any{ + {"key": "value"}, + }, + }, + { + name: "nested map", + obj: map[string]any{ + "outer": map[string]any{ + "inner": "value", + }, + }, + want: []map[string]any{ + {"outer": map[string]any{"inner": "value"}}, + {"inner": "value"}, + }, + }, + { + name: "array of maps", + obj: []any{ + map[string]any{"key1": "val1"}, + map[string]any{"key2": "val2"}, + }, + want: []map[string]any{ + {"key1": "val1"}, + {"key2": "val2"}, + }, + }, + { + name: "deeply nested", + obj: map[string]any{ + "l1": map[string]any{ + "l2": map[string]any{ + "l3": "value", + }, + }, + }, + want: []map[string]any{ + {"l1": map[string]any{"l2": map[string]any{"l3": "value"}}}, + {"l2": map[string]any{"l3": "value"}}, + {"l3": "value"}, + }, + }, + { + name: "non-map value", + obj: "string", + want: nil, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := collect(tt.obj) + if len(got) != len(tt.want) { + t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want)) + return + } + + // Compare each map in the result + for i := range tt.want { + if !mapsEqual(got[i], tt.want[i]) { + t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i]) + } + } + }) + } +} + +// mapsEqual compares two maps for deep equality +func mapsEqual(m1, m2 map[string]any) bool { + if len(m1) != len(m2) { + return false + } + for k, v1 := range m1 { + v2, ok := m2[k] + if !ok { + return false + } + switch val1 := v1.(type) { + case map[string]any: + val2, ok := v2.(map[string]any) + if !ok || !mapsEqual(val1, val2) { + return false + } + default: + if v1 != v2 { + return false + } + } + } + return true +}