tools package and utils

This commit is contained in:
ParthSareen 2025-05-12 18:02:18 -07:00
parent 4059b8db01
commit bc83789be9
7 changed files with 667 additions and 592 deletions

View File

@ -10,9 +10,6 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
"slices"
gotmpl "text/template"
"text/template/parse"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
@ -129,19 +126,19 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil return "unknown", nil
} }
func ToolTemplate(m *Model) (*gotmpl.Template, bool) { // func ToolTemplate(m *Model) (*gotmpl.Template, bool) {
// create a subtree from the node that ranges over .ToolCalls // // create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool { // tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok { // if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") // return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
} // }
return false // return false
}) // })
if tmpl == nil { // if tmpl == nil {
return nil, false // return nil, false
} // }
return tmpl, true // return tmpl, true
} // }

View File

@ -1,185 +1,185 @@
package server package server
import ( // import (
"testing" // "testing"
gotmpl "text/template" // gotmpl "text/template"
) // )
func TestToolToken(t *testing.T) { // func TestToolToken(t *testing.T) {
cases := []struct { // cases := []struct {
name string // name string
template string // template string
want string // want string
ok bool // ok bool
}{ // }{
{ // {
name: "basic tool call with action prefix", // name: "basic tool call with action prefix",
template: "{{if .ToolCalls}}Action: ```json{{end}}", // template: "{{if .ToolCalls}}Action: ```json{{end}}",
want: "Action:", // want: "Action:",
ok: true, // ok: true,
}, // },
{ // {
name: "incomplete functools bracket", // name: "incomplete functools bracket",
template: "{{if .ToolCalls}}functools[{{end}}", // template: "{{if .ToolCalls}}functools[{{end}}",
want: "functools", // want: "functools",
ok: true, // ok: true,
}, // },
{ // {
name: "tool call with angle brackets", // name: "tool call with angle brackets",
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}", // template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
want: "<tool_call>", // want: "<tool_call>",
ok: true, // ok: true,
}, // },
{ // {
name: "multiple tool call formats", // name: "multiple tool call formats",
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}", // template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
want: "[tool_call]", // want: "[tool_call]",
ok: true, // ok: true,
}, // },
{ // {
name: "single angle bracket tool call", // name: "single angle bracket tool call",
template: "{{if .ToolCalls}}<tool_call>{{end}}", // template: "{{if .ToolCalls}}<tool_call>{{end}}",
want: "<tool_call>", // want: "<tool_call>",
ok: true, // ok: true,
}, // },
{ // {
name: "incomplete angle bracket after tool call", // name: "incomplete angle bracket after tool call",
template: "{{if .ToolCalls}}[tool_call] <{{end}}", // template: "{{if .ToolCalls}}[tool_call] <{{end}}",
want: "[tool_call]", // want: "[tool_call]",
ok: true, // ok: true,
}, // },
{ // {
name: "angle bracket prefix with tool call", // name: "angle bracket prefix with tool call",
template: "{{if .ToolCalls}}> <tool_call>{{end}}", // template: "{{if .ToolCalls}}> <tool_call>{{end}}",
want: "<tool_call>", // want: "<tool_call>",
ok: true, // ok: true,
}, // },
{ // {
name: "uppercase tool call with incomplete bracket", // name: "uppercase tool call with incomplete bracket",
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", // template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
want: "[TOOL_CALL]", // want: "[TOOL_CALL]",
ok: true, // ok: true,
}, // },
{ // {
name: "uppercase tool call with adjacent bracket", // name: "uppercase tool call with adjacent bracket",
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", // template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
want: "[TOOL_CALL]", // want: "[TOOL_CALL]",
ok: true, // ok: true,
}, // },
{ // {
name: "tool call with pipe delimiters", // name: "tool call with pipe delimiters",
template: "{{if .ToolCalls}}<|tool_call|>{{end}}", // template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
want: "<|tool_call|>", // want: "<|tool_call|>",
ok: true, // ok: true,
}, // },
} // }
for _, tt := range cases { // for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) { // t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template) // tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil { // if err != nil {
t.Fatalf("failed to parse template: %v", err) // t.Fatalf("failed to parse template: %v", err)
} // }
got, ok := ToolPrefix(tmpl) // got, ok := ToolPrefix(tmpl)
if got != tt.want { // if got != tt.want {
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) // t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
} // }
if ok != tt.ok { // if ok != tt.ok {
t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok) // t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok)
} // }
}) // })
} // }
} // }
func TestTextAfterToolCalls(t *testing.T) { // func TestTextAfterToolCalls(t *testing.T) {
cases := []struct { // cases := []struct {
name string // name string
template string // template string
want string // want string
ok bool // ok bool
}{ // }{
{ // {
name: "basic tool call with text after", // name: "basic tool call with text after",
template: `{{if .ToolCalls}}tool response{{end}}`, // template: `{{if .ToolCalls}}tool response{{end}}`,
want: "tool response", // want: "tool response",
ok: true, // ok: true,
}, // },
{ // {
name: "tool call with mixed content after", // name: "tool call with mixed content after",
template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`, // template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`,
want: "<tool_call>", // want: "<tool_call>",
ok: true, // ok: true,
}, // },
{ // {
name: "tool call with no text after", // name: "tool call with no text after",
template: `{{if .ToolCalls}}{{.Something}}{{end}}`, // template: `{{if .ToolCalls}}{{.Something}}{{end}}`,
want: "", // want: "",
ok: true, // ok: true,
}, // },
{ // {
name: "nested tool call", // name: "nested tool call",
template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, // template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`,
want: "[TOOL_CALL]", // want: "[TOOL_CALL]",
ok: true, // ok: true,
}, // },
{ // {
name: "no tool calls", // name: "no tool calls",
template: `{{if .Something}}no tools here{{end}}`, // template: `{{if .Something}}no tools here{{end}}`,
want: "", // want: "",
ok: false, // ok: false,
}, // },
{ // {
name: "empty template", // name: "empty template",
template: ``, // template: ``,
want: "", // want: "",
ok: false, // ok: false,
}, // },
{ // {
name: "multiple tool calls sections", // name: "multiple tool calls sections",
template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, // template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`,
want: "first", // want: "first",
ok: true, // ok: true,
}, // },
{ // {
name: "range over tool calls", // name: "range over tool calls",
template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, // template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`,
want: "", // want: "",
ok: true, // ok: true,
}, // },
{ // {
name: "tool calls with pipe delimiters", // name: "tool calls with pipe delimiters",
template: `{{if .ToolCalls}}<|tool|>{{end}}`, // template: `{{if .ToolCalls}}<|tool|>{{end}}`,
want: "<|tool|>", // want: "<|tool|>",
ok: true, // ok: true,
}, // },
{ // {
name: "tool calls with nested template", // name: "tool calls with nested template",
template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, // template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`,
want: "", // want: "",
ok: true, // ok: true,
}, // },
{ // {
name: "tool calls with whitespace variations", // name: "tool calls with whitespace variations",
template: `{{if .ToolCalls}} tool {{end}}`, // template: `{{if .ToolCalls}} tool {{end}}`,
want: " tool ", // want: " tool ",
ok: true, // ok: true,
}, // },
} // }
for _, tt := range cases { // for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) { // t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template) // tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil { // if err != nil {
t.Fatalf("failed to parse template: %v", err) // t.Fatalf("failed to parse template: %v", err)
} // }
got, ok := extractToolCallsTemplate(tmpl) // got, ok := extractToolCallsTemplate(tmpl)
if got != tt.want { // if got != tt.want {
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) // t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
} // }
if ok != tt.ok { // if ok != tt.ok {
t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) // t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok)
} // }
}) // })
} // }
} // }

View File

@ -1483,19 +1483,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
slog.Debug("chat request", "images", len(images), "prompt", prompt)
var toolParser *tools.Parser
if len(req.Tools) > 0 {
toolParser, err = tools.NewParser(m.Template.Template)
if err != nil {
slog.Error("failed to create tool parser", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
// ! personally not a fan of this pattern
toolTemplate, ok := ToolTemplate(m)
if !ok {
slog.Error("tool template not found", "model", m.Name)
return
}
var toolParser *tools.Parser
if len(req.Tools) > 0 {
toolParser = tools.NewParser(m.Template.Template, toolTemplate)
}
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
@ -1523,30 +1525,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
if len(req.Tools) > 0 && !toolParser.Done { if len(req.Tools) > 0 && !toolParser.Done {
toolCalls, leftover := toolParser.ParseToolCalls(r.Content) toolCalls, content, err := toolParser.Add(r.Content)
// * This can be abstracted again to a .handleState(tp.state) if err == nil {
// * However, we'd need a flag to indicate whether to send the response or not if len(content) > 0 {
// * happy to take whatever is more idiomatic res.Message.Content = content
switch toolParser.ParserState { fmt.Println("sending content in response", content)
case tools.ToolCallAccumulate: } else if len(toolCalls) > 0 {
// tokens are accumulated in the tool parser fmt.Println("sending tool calls in response", toolCalls)
return res.Message.ToolCalls = toolCalls
case tools.ToolCallSendTokens: res.Message.Content = ""
// tokens are sent back in the response } else {
case tools.ToolCallSendPartial: return
// tokens not needed for parsing are sent back in the response
if len(leftover) > 0 {
res.Message.Content = leftover
} }
// ! state is needed as we need to not match on the other states
case tools.ToolCallFound:
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} }
} }
fmt.Println("sending response", res.Message.Content)
// * this is where we'd need the flag if we have a .handleState(tp.state)
ch <- res ch <- res
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}

View File

@ -6,91 +6,28 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"slices"
"strings" "strings"
gotmpl "text/template" gotmpl "text/template"
"text/template/parse"
jsonv2 "github.com/go-json-experiment/json" jsonv2 "github.com/go-json-experiment/json"
jsontext "github.com/go-json-experiment/json/jsontext" jsontext "github.com/go-json-experiment/json/jsontext"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
type State int
// TODO: potentially coalesce states
const (
SendTokens State = iota
GreedyToolWithPrefix
GreedyToolNoPrefix
ForceTools
ToolSuffix
ContainsPrefix
PartialPrefix
NotPartialPrefix
Done
)
type ExternalState int
const (
ToolCallFound ExternalState = iota
ToolCallSendPartial
ToolCallAccumulate
ToolCallSendTokens
)
func (s ExternalState) String() string {
switch s {
case ToolCallFound:
return "ToolCallFound"
case ToolCallSendPartial:
return "ToolCallSendPartial"
case ToolCallAccumulate:
return "ToolCallAccumulate"
case ToolCallSendTokens:
return "ToolCallSendTokens"
default:
return fmt.Sprintf("Unknown ExternalState (%d)", s)
}
}
func (s State) String() string {
switch s {
case SendTokens:
return "SendTokens"
case GreedyToolWithPrefix:
return "GreedyToolWithPrefix"
case GreedyToolNoPrefix:
return "GreedyToolNoPrefix"
case ForceTools:
return "ForceTools"
case ToolSuffix:
return "ToolSuffix"
case PartialPrefix:
return "PossiblePrefix"
case Done:
return "Done"
case ContainsPrefix:
return "PartialPrefix"
default:
return fmt.Sprintf("Unknown State (%d)", s)
}
}
// TODO: simplify if possible // TODO: simplify if possible
type Parser struct { type Parser struct {
tmpl *gotmpl.Template greedy bool
state State prefixFound bool
sb *strings.Builder partialPrefix bool
toolPrefix string tmpl *gotmpl.Template
toolIndex int sb *strings.Builder
ParserState ExternalState prefix string
Done bool index int
Done bool
} }
// ? move to a separate file
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found // Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
@ -222,314 +159,126 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
return toolCalls, false, true return toolCalls, false, true
} }
// TODO: clean up the boundary of internal and external state transitions // prefix stripped string if any, prefix found, and if we should accumulate
func (p *Parser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) { func (p *Parser) checkPrefix(s string) (string, bool, bool) {
fmt.Printf("updating output state: ok=%v partial=%v tool_calls=%d current_state=%s\n", ok, partial, len(tcs), p.state)
// state transition logic if p.prefix == "" {
switch { return s, false, true
case !ok && !partial && p.state == ForceTools:
// force partial tool if we have a prefix
// no op and stay in force tools
p.sb.Reset()
case !ok && !partial:
if p.state == GreedyToolNoPrefix {
p.state = Done
// ? the output parser state is the same even though internal can we not leak the external state?
p.Done = true
}
if p.state == GreedyToolWithPrefix {
p.state = SendTokens
}
if p.state == PartialPrefix {
p.state = NotPartialPrefix
}
case !ok && partial:
// acucumulate
case len(tcs) > 0:
// do not parse again in the greedy JSON case as soon as we have a tool call
p.sb.Reset()
}
p.updateExternalState(tcs)
fmt.Printf("state updated: new_state=%s parser_state=%s\n", p.state, p.ParserState)
}
func (p *Parser) updateExternalState(tcs []api.ToolCall) {
fmt.Printf("updating external state: current_state=%s tool_calls=%d\n", p.state, len(tcs))
switch {
case len(tcs) > 0:
// do not parse again in the greedy JSON case as soon as we have a tool call
if p.state == GreedyToolWithPrefix {
p.state = SendTokens
} else if p.state == GreedyToolNoPrefix {
p.state = Done
p.Done = true
}
p.ParserState = ToolCallFound
case p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix ||
p.state == ToolSuffix || p.state == PartialPrefix ||
(p.state == ForceTools && len(tcs) == 0):
p.ParserState = ToolCallAccumulate
case p.state == ContainsPrefix:
p.ParserState = ToolCallSendPartial
case p.state == SendTokens || p.state == Done:
p.ParserState = ToolCallSendTokens
case p.state == NotPartialPrefix:
p.ParserState = ToolCallSendPartial
default:
p.ParserState = ToolCallSendTokens
p.sb.Reset()
p.state = SendTokens
}
}
// string, and if it has a prefix
func (p *Parser) checkPrefix(s string) (string, bool) {
fmt.Printf("checking prefix: input=%s prefix=%s\n", s, p.toolPrefix)
if p.toolPrefix == "" {
return s, true
} }
original := s original := s
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix) s = strings.TrimSpace(s)
s, hasPrefix := strings.CutPrefix(s, p.prefix)
if hasPrefix { if hasPrefix {
p.state = ForceTools
fmt.Printf("found exact prefix match: remaining=%s\n", s)
// partial tool possibly - accumulate // partial tool possibly - accumulate
} else if suffixOverlap(s, p.toolPrefix) > 0 { return s, true, true
p.state = PartialPrefix } else if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
fmt.Printf("found partial prefix: remaining=%s\n", s) // p.state = PartialPrefix
return "", false p.partialPrefix = true
// the case where "token<tool_call>" - send "token" back return original[0 : len(original)-overlap], false, false
} else if idx := strings.Index(original, p.prefix); idx != -1 {
// Found prefix in middle of string, keep only content before prefix
// accounts for spaces in prefix or suffix to avoid breaking cache // accounts for spaces in prefix or suffix to avoid breaking cache
} else if strings.Contains(original, p.toolPrefix) { p.partialPrefix = true
idx := strings.Index(original, p.toolPrefix) p.sb.Reset()
if idx != -1 {
// still keeps the prefix p.sb.WriteString(strings.TrimSpace(original[idx:]))
p.state = ContainsPrefix return original[:idx], false, false
p.sb.Reset()
// todo: see if there is a simpler way for this
idx2 := strings.Index(s, p.toolPrefix)
// buffer now only has the prefix
p.sb.WriteString(s[idx2:])
fmt.Printf("found prefix in middle: prefix_start=%d content_before=%s\n", idx, original[:idx])
return original[:idx], false
}
} }
return s, true p.partialPrefix = false
return s, false, true
} }
// TODO: simplify the flow of this function func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) {
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. slog.Debug("adding tool calls", "input", s)
// Returns tool calls, whether parsing is incomplete, and any errors.
func (p *Parser) ParseToolCalls(s string) ([]api.ToolCall, string) {
fmt.Printf("parsing tool calls: input=%s current_state=%s\n", s, p.state)
p.sb.WriteString(s) p.sb.WriteString(s)
s = p.sb.String() s = p.sb.String()
s = strings.TrimSpace(s)
if len(s) == 0 { if len(s) == 0 {
p.updateExternalState(nil) return nil, "", nil
return nil, ""
} }
s, cont := p.checkPrefix(s) s, prefixFound, cont := p.checkPrefix(s)
if !cont { if !cont {
p.updateExternalState(nil) if s != "" {
if p.state == ContainsPrefix { // send only the content back, prefix exists
fmt.Printf("returning partial prefix: remaining=%s\n", s) return nil, s, nil
return nil, s
} }
// * we'd be returning here for just accumulating with possible prefix // accumulate case
// * ext state is accumulation return nil, "", nil
return nil, ""
} }
// * lets say the check fails here and now we're still in external state accumulation here
// stay in SendTokens unless we have a prefix // circuit breaker
if p.state == SendTokens { if prefixFound {
p.updateExternalState(nil) p.prefixFound = true
}
// for cases with a prefix in template
if p.prefix != "" && !p.greedy && !p.prefixFound {
// send tokens down
p.sb.Reset() p.sb.Reset()
fmt.Printf("returning send tokens: remaining=%s\n", s) return nil, "", errors.New("prefix not found")
return nil, s
} }
// we have a prefix or are in json mode
// * we'd parse here as json to see if it's a tool call
tcs, partial, ok := p.parseJSONToolCalls(s) tcs, partial, ok := p.parseJSONToolCalls(s)
// * it would not be a tool call here if partial {
p.updateStateAfterJSONParse(ok, partial, tcs) // accumulate case
if !ok { return nil, "", nil
// * and so we should send the data here
// * we also need to move out of that internal state after sending the tokens
if p.state == NotPartialPrefix {
p.state = SendTokens
// the string would have acc until here
return nil, p.sb.String()
}
return nil, ""
} }
p.greedy = false
if !ok {
// will not be a partial at this point
p.sb.Reset()
// send tokens
if p.prefix == "" {
p.Done = true
}
if p.prefixFound {
// drop tokens instead - sb is reset, no tokens sent to user
return nil, "", nil
}
return nil, "", errors.New("failed to parse tool calls")
}
for _, tc := range tcs { for _, tc := range tcs {
tc.Function.Index = p.toolIndex tc.Function.Index = p.index
p.toolIndex++ p.index++
} }
fmt.Printf("finished parsing tool calls: tool_calls_found=%d\n", len(tcs)) if p.prefix == "" {
return tcs, "" p.Done = true
}
p.sb.Reset()
return tcs, "", nil
} }
func suffixOverlap(s, delim string) int { func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
max := min(len(delim), len(s)) parsedTemplate, err := template.Parse(templateToProcess.Root.String())
for i := max; i > 0; i-- { if err != nil {
if strings.HasSuffix(s, delim[:i]) { return nil, err
return i
}
} }
return 0 if parsedTemplate == nil {
} return nil, errors.New("failed to parse template")
// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls"
func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
if tmpl == nil || tmpl.Tree == nil {
slog.Debug("TextAfterToolCalls: template or tree is nil")
return "", false
} }
var result string toolCallTemplate, hasToolCalls := toolTemplate(parsedTemplate)
var found bool if !hasToolCalls {
return nil, errors.New("failed to find tool template")
var walk func(nodes []parse.Node) }
walk = func(nodes []parse.Node) { if toolCallTemplate == nil {
for _, node := range nodes { return nil, errors.New("failed to find tool template")
if found {
return
}
switch n := node.(type) {
case *parse.IfNode:
if nodeContainsToolCalls(n) {
// Collect immediate TextNode(s) at start of IfNode's list
var sb strings.Builder
for _, innerNode := range n.List.Nodes {
if tn, ok := innerNode.(*parse.TextNode); ok {
sb.Write(tn.Text)
} else {
// Stop at first non-text node
break
}
}
result = sb.String()
found = true
return
}
// Recurse into child nodes
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.ListNode:
walk(n.Nodes)
case *parse.RangeNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.WithNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
default:
// Continue to next node
continue
}
if found {
return
}
}
} }
walk(tmpl.Tree.Root.Nodes) toolPrefix, _ := ToolPrefix(templateToProcess)
return result, found toolPrefix = strings.TrimSpace(toolPrefix)
}
// Helper to detect if a node's condition includes ".ToolCalls" fmt.Printf("creating new tool parser: prefix=%s\n", toolPrefix)
func nodeContainsToolCalls(n *parse.IfNode) bool {
for _, cmd := range n.Pipe.Cmds {
for _, arg := range cmd.Args {
if field, ok := arg.(*parse.FieldNode); ok {
if slices.Contains(field.Ident, "ToolCalls") {
return true
}
}
}
}
return false
}
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
tokenText, ok := extractToolCallsTemplate(tmpl)
if !ok {
return "", false
}
tokenText = strings.TrimSpace(tokenText)
if tokenText == "" {
return "", false
}
first := strings.Fields(tokenText)[0]
start := -1
end := -1
for i, r := range tokenText {
if r == '<' || r == '[' {
start = i
}
if (r == '>' || r == ']') && start != -1 {
end = i
break
}
}
if start != -1 && end != -1 {
// return the token including the [ or < and the ] or >
return tokenText[start : end+1], true
} else if start != -1 {
// get until the [ or < - in the case tag was not closed
return tokenText[:start], true
} else if end != -1 {
// get after the ] or > - in the case tag was not opened
return tokenText[end+1:], true
}
return first, true
}
func NewParser(tmpl *gotmpl.Template, toolTemplate *gotmpl.Template) *Parser {
// TODO: use new template parsing to get all tokens for the prefix
if tmpl == nil {
return nil
}
if toolTemplate == nil {
return nil
}
prefix, _ := ToolPrefix(tmpl)
prefix = strings.TrimSpace(prefix)
var state State
if prefix == "" {
state = GreedyToolNoPrefix
} else {
state = GreedyToolWithPrefix
}
fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", prefix, state)
return &Parser{ return &Parser{
tmpl: toolTemplate, tmpl: toolCallTemplate,
sb: &strings.Builder{}, sb: &strings.Builder{},
toolPrefix: prefix, prefix: toolPrefix,
state: state, greedy: true,
ParserState: ToolCallAccumulate, }, nil
}
} }

View File

@ -239,14 +239,14 @@ func TestParseToolCalls(t *testing.T) {
model: "qwen3", model: "qwen3",
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, expectedTokens: `<think></think> fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
}, },
{ {
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
model: "qwen3", model: "qwen3",
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, expectedTokens: `<think></think> fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
}, },
{ {
name: "qwen3 invalid tool call with malformed tool prefix", name: "qwen3 invalid tool call with malformed tool prefix",
@ -315,34 +315,31 @@ func TestParseToolCalls(t *testing.T) {
t.Run("parse", func(t *testing.T) { t.Run("parse", func(t *testing.T) {
// fmt.Printf("tmpl: %s\n", tmpl.Root.String()) // fmt.Printf("tmpl: %s\n", tmpl.Root.String())
toolTemplate, ok := toolTemplateHelper(t, tmpl) tp, err := NewParser(tmpl.Template)
if !ok { if err != nil {
t.Fatalf("tool template not found for model %s", tt.model) t.Fatal(err)
} }
tp := NewParser(tmpl.Template, toolTemplate)
got := []api.ToolCall{} got := []api.ToolCall{}
var gotTokens strings.Builder var gotTokens strings.Builder
var add bool
tokens := strings.Fields(tt.output) tokens := strings.Fields(tt.output)
for _, tok := range tokens { for _, tok := range tokens {
add := true
s := " " + tok s := " " + tok
add = true
if !tp.Done { if !tp.Done {
toolCalls, leftover := tp.ParseToolCalls(s) toolCalls, content, err := tp.Add(s)
switch tp.ParserState { if err == nil {
case ToolCallFound: if content != "" {
got = append(got, toolCalls...) gotTokens.WriteString(content)
add = false add = false
case ToolCallSendTokens: } else if len(toolCalls) > 0 {
gotTokens.WriteString(s) got = append(got, toolCalls...)
add = false add = false
case ToolCallAccumulate: } else {
add = false add = false
case ToolCallSendPartial: }
t.Log("send partial", "leftover", leftover)
gotTokens.WriteString(" " + leftover)
add = false
} }
} }
if add { if add {

155
tools/utils.go Normal file
View File

@ -0,0 +1,155 @@
package tools
import (
"log/slog"
"slices"
"strings"
gotmpl "text/template"
"text/template/parse"
"github.com/ollama/ollama/template"
)
// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls"
func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
if tmpl == nil || tmpl.Tree == nil {
slog.Debug("TextAfterToolCalls: template or tree is nil")
return "", false
}
var result string
var found bool
var walk func(nodes []parse.Node)
walk = func(nodes []parse.Node) {
for _, node := range nodes {
if found {
return
}
switch n := node.(type) {
case *parse.IfNode:
if nodeContainsToolCalls(n) {
// Collect immediate TextNode(s) at start of IfNode's list
var sb strings.Builder
for _, innerNode := range n.List.Nodes {
if tn, ok := innerNode.(*parse.TextNode); ok {
sb.Write(tn.Text)
} else {
// Stop at first non-text node
break
}
}
result = sb.String()
found = true
return
}
// Recurse into child nodes
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.ListNode:
walk(n.Nodes)
case *parse.RangeNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.WithNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
default:
// Continue to next node
continue
}
if found {
return
}
}
}
walk(tmpl.Tree.Root.Nodes)
return result, found
}
// Helper to detect if a node's condition includes ".ToolCalls"
func nodeContainsToolCalls(n *parse.IfNode) bool {
for _, cmd := range n.Pipe.Cmds {
for _, arg := range cmd.Args {
if field, ok := arg.(*parse.FieldNode); ok {
if slices.Contains(field.Ident, "ToolCalls") {
return true
}
}
}
}
return false
}
// ToolPrefix returns the prefix for the tool call if it exists
// TODO(parthsareen): get full prefix from the template instead of just the first token
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
tokenText, ok := extractToolCallsTemplate(tmpl)
if !ok {
return "", false
}
tokenText = strings.TrimSpace(tokenText)
if tokenText == "" {
return "", false
}
first := strings.Fields(tokenText)[0]
start := -1
end := -1
for i, r := range tokenText {
if r == '<' || r == '[' {
start = i
}
if (r == '>' || r == ']') && start != -1 {
end = i
break
}
}
if start != -1 && end != -1 {
// return the token including the [ or < and the ] or >
return tokenText[start : end+1], true
} else if start != -1 {
// get until the [ or < - in the case tag was not closed
return tokenText[:start], true
} else if end != -1 {
// get after the ] or > - in the case tag was not opened
return tokenText[end+1:], true
}
return first, true
}
func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := t.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
return tmpl, true
}
func suffixOverlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}

185
tools/utils_test.go Normal file
View File

@ -0,0 +1,185 @@
package tools
import (
"testing"
gotmpl "text/template"
)
func TestToolPrefix(t *testing.T) {
cases := []struct {
name string
template string
want string
ok bool
}{
{
name: "basic tool call with action prefix",
template: "{{if .ToolCalls}}Action: ```json{{end}}",
want: "Action:",
ok: true,
},
{
name: "incomplete functools bracket",
template: "{{if .ToolCalls}}functools[{{end}}",
want: "functools",
ok: true,
},
{
name: "tool call with angle brackets",
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
want: "<tool_call>",
ok: true,
},
{
name: "multiple tool call formats",
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
want: "[tool_call]",
ok: true,
},
{
name: "single angle bracket tool call",
template: "{{if .ToolCalls}}<tool_call>{{end}}",
want: "<tool_call>",
ok: true,
},
{
name: "incomplete angle bracket after tool call",
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
want: "[tool_call]",
ok: true,
},
{
name: "angle bracket prefix with tool call",
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
want: "<tool_call>",
ok: true,
},
{
name: "uppercase tool call with incomplete bracket",
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
want: "[TOOL_CALL]",
ok: true,
},
{
name: "uppercase tool call with adjacent bracket",
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
want: "[TOOL_CALL]",
ok: true,
},
{
name: "tool call with pipe delimiters",
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
want: "<|tool_call|>",
ok: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
got, ok := ToolPrefix(tmpl)
if got != tt.want {
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
}
if ok != tt.ok {
t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok)
}
})
}
}
func TestTextAfterToolCalls(t *testing.T) {
cases := []struct {
name string
template string
want string
ok bool
}{
{
name: "basic tool call with text after",
template: `{{if .ToolCalls}}tool response{{end}}`,
want: "tool response",
ok: true,
},
{
name: "tool call with mixed content after",
template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`,
want: "<tool_call>",
ok: true,
},
{
name: "tool call with no text after",
template: `{{if .ToolCalls}}{{.Something}}{{end}}`,
want: "",
ok: true,
},
{
name: "nested tool call",
template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`,
want: "[TOOL_CALL]",
ok: true,
},
{
name: "no tool calls",
template: `{{if .Something}}no tools here{{end}}`,
want: "",
ok: false,
},
{
name: "empty template",
template: ``,
want: "",
ok: false,
},
{
name: "multiple tool calls sections",
template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`,
want: "first",
ok: true,
},
{
name: "range over tool calls",
template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`,
want: "",
ok: true,
},
{
name: "tool calls with pipe delimiters",
template: `{{if .ToolCalls}}<|tool|>{{end}}`,
want: "<|tool|>",
ok: true,
},
{
name: "tool calls with nested template",
template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`,
want: "",
ok: true,
},
{
name: "tool calls with whitespace variations",
template: `{{if .ToolCalls}} tool {{end}}`,
want: " tool ",
ok: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
got, ok := extractToolCallsTemplate(tmpl)
if got != tt.want {
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
}
if ok != tt.ok {
t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok)
}
})
}
}