From bc83789be9fdb9e41c061997e20ff5402f6ec6d1 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Mon, 12 May 2025 18:02:18 -0700 Subject: [PATCH] tools package and utils --- server/model.go | 29 ++- server/model_test.go | 356 +++++++++++++++++----------------- server/routes.go | 54 +++--- tools/tools.go | 441 ++++++++++--------------------------------- tools/tools_test.go | 39 ++-- tools/utils.go | 155 +++++++++++++++ tools/utils_test.go | 185 ++++++++++++++++++ 7 files changed, 667 insertions(+), 592 deletions(-) create mode 100644 tools/utils.go create mode 100644 tools/utils_test.go diff --git a/server/model.go b/server/model.go index 7e749829c..e9b57eb75 100644 --- a/server/model.go +++ b/server/model.go @@ -10,9 +10,6 @@ import ( "log/slog" "net/http" "os" - "slices" - gotmpl "text/template" - "text/template/parse" "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" @@ -129,19 +126,19 @@ func detectContentType(r io.Reader) (string, error) { return "unknown", nil } -func ToolTemplate(m *Model) (*gotmpl.Template, bool) { - // create a subtree from the node that ranges over .ToolCalls - tmpl := m.Template.Subtree(func(n parse.Node) bool { - if t, ok := n.(*parse.RangeNode); ok { - return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") - } +// func ToolTemplate(m *Model) (*gotmpl.Template, bool) { +// // create a subtree from the node that ranges over .ToolCalls +// tmpl := m.Template.Subtree(func(n parse.Node) bool { +// if t, ok := n.(*parse.RangeNode); ok { +// return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") +// } - return false - }) +// return false +// }) - if tmpl == nil { - return nil, false - } +// if tmpl == nil { +// return nil, false +// } - return tmpl, true -} +// return tmpl, true +// } diff --git a/server/model_test.go b/server/model_test.go index 8fd19d2db..7458b1dbc 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -1,185 +1,185 @@ package server -import ( - "testing" - gotmpl "text/template" -) +// import ( +// "testing" +// gotmpl "text/template" +// ) -func TestToolToken(t *testing.T) { - cases := []struct { - name string - template string - want string - ok bool - }{ - { - name: "basic tool call with action prefix", - template: "{{if .ToolCalls}}Action: ```json{{end}}", - want: "Action:", - ok: true, - }, - { - name: "incomplete functools bracket", - template: "{{if .ToolCalls}}functools[{{end}}", - want: "functools", - ok: true, - }, - { - name: "tool call with angle brackets", - template: "{{if .ToolCalls}}Hello, world! {{end}}", - want: "", - ok: true, - }, - { - name: "multiple tool call formats", - template: "{{if .ToolCalls}}[tool_call] {{end}}", - want: "[tool_call]", - ok: true, - }, - { - name: "single angle bracket tool call", - template: "{{if .ToolCalls}}{{end}}", - want: "", - ok: true, - }, - { - name: "incomplete angle bracket after tool call", - template: "{{if .ToolCalls}}[tool_call] <{{end}}", - want: "[tool_call]", - ok: true, - }, - { - name: "angle bracket prefix with tool call", - template: "{{if .ToolCalls}}> {{end}}", - want: "", - ok: true, - }, - { - name: "uppercase tool call with incomplete bracket", - template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", - want: "[TOOL_CALL]", - ok: true, - }, - { - name: "uppercase tool call with adjacent bracket", - template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", - want: "[TOOL_CALL]", - ok: true, - }, - { - name: "tool call with pipe delimiters", - template: "{{if .ToolCalls}}<|tool_call|>{{end}}", - want: "<|tool_call|>", - ok: true, - }, - } +// func TestToolToken(t *testing.T) { +// cases := []struct { +// name string +// template string +// want string +// ok bool +// }{ +// { +// name: "basic tool call with action prefix", +// template: "{{if .ToolCalls}}Action: ```json{{end}}", +// want: "Action:", +// ok: true, +// }, +// { +// name: "incomplete functools bracket", +// template: "{{if .ToolCalls}}functools[{{end}}", +// want: "functools", +// ok: true, +// }, +// { +// name: "tool call with angle brackets", +// template: "{{if .ToolCalls}}Hello, world! {{end}}", +// want: "", +// ok: true, +// }, +// { +// name: "multiple tool call formats", +// template: "{{if .ToolCalls}}[tool_call] {{end}}", +// want: "[tool_call]", +// ok: true, +// }, +// { +// name: "single angle bracket tool call", +// template: "{{if .ToolCalls}}{{end}}", +// want: "", +// ok: true, +// }, +// { +// name: "incomplete angle bracket after tool call", +// template: "{{if .ToolCalls}}[tool_call] <{{end}}", +// want: "[tool_call]", +// ok: true, +// }, +// { +// name: "angle bracket prefix with tool call", +// template: "{{if .ToolCalls}}> {{end}}", +// want: "", +// ok: true, +// }, +// { +// name: "uppercase tool call with incomplete bracket", +// template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", +// want: "[TOOL_CALL]", +// ok: true, +// }, +// { +// name: "uppercase tool call with adjacent bracket", +// template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", +// want: "[TOOL_CALL]", +// ok: true, +// }, +// { +// name: "tool call with pipe delimiters", +// template: "{{if .ToolCalls}}<|tool_call|>{{end}}", +// want: "<|tool_call|>", +// ok: true, +// }, +// } - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tt.template) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } - got, ok := ToolPrefix(tmpl) - if got != tt.want { - t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) - } - if ok != tt.ok { - t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok) - } - }) - } -} +// for _, tt := range cases { +// t.Run(tt.name, func(t *testing.T) { +// tmpl, err := gotmpl.New("test").Parse(tt.template) +// if err != nil { +// t.Fatalf("failed to parse template: %v", err) +// } +// got, ok := ToolPrefix(tmpl) +// if got != tt.want { +// t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) +// } +// if ok != tt.ok { +// t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok) +// } +// }) +// } +// } -func TestTextAfterToolCalls(t *testing.T) { - cases := []struct { - name string - template string - want string - ok bool - }{ - { - name: "basic tool call with text after", - template: `{{if .ToolCalls}}tool response{{end}}`, - want: "tool response", - ok: true, - }, - { - name: "tool call with mixed content after", - template: `{{if .ToolCalls}}{{.Something}}{{end}}`, - want: "", - ok: true, - }, - { - name: "tool call with no text after", - template: `{{if .ToolCalls}}{{.Something}}{{end}}`, - want: "", - ok: true, - }, - { - name: "nested tool call", - template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, - want: "[TOOL_CALL]", - ok: true, - }, - { - name: "no tool calls", - template: `{{if .Something}}no tools here{{end}}`, - want: "", - ok: false, - }, - { - name: "empty template", - template: ``, - want: "", - ok: false, - }, - { - name: "multiple tool calls sections", - template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, - want: "first", - ok: true, - }, - { - name: "range over tool calls", - template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, - want: "", - ok: true, - }, - { - name: "tool calls with pipe delimiters", - template: `{{if .ToolCalls}}<|tool|>{{end}}`, - want: "<|tool|>", - ok: true, - }, - { - name: "tool calls with nested template", - template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, - want: "", - ok: true, - }, - { - name: "tool calls with whitespace variations", - template: `{{if .ToolCalls}} tool {{end}}`, - want: " tool ", - ok: true, - }, - } +// func TestTextAfterToolCalls(t *testing.T) { +// cases := []struct { +// name string +// template string +// want string +// ok bool +// }{ +// { +// name: "basic tool call with text after", +// template: `{{if .ToolCalls}}tool response{{end}}`, +// want: "tool response", +// ok: true, +// }, +// { +// name: "tool call with mixed content after", +// template: `{{if .ToolCalls}}{{.Something}}{{end}}`, +// want: "", +// ok: true, +// }, +// { +// name: "tool call with no text after", +// template: `{{if .ToolCalls}}{{.Something}}{{end}}`, +// want: "", +// ok: true, +// }, +// { +// name: "nested tool call", +// template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, +// want: "[TOOL_CALL]", +// ok: true, +// }, +// { +// name: "no tool calls", +// template: `{{if .Something}}no tools here{{end}}`, +// want: "", +// ok: false, +// }, +// { +// name: "empty template", +// template: ``, +// want: "", +// ok: false, +// }, +// { +// name: "multiple tool calls sections", +// template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, +// want: "first", +// ok: true, +// }, +// { +// name: "range over tool calls", +// template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, +// want: "", +// ok: true, +// }, +// { +// name: "tool calls with pipe delimiters", +// template: `{{if .ToolCalls}}<|tool|>{{end}}`, +// want: "<|tool|>", +// ok: true, +// }, +// { +// name: "tool calls with nested template", +// template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, +// want: "", +// ok: true, +// }, +// { +// name: "tool calls with whitespace variations", +// template: `{{if .ToolCalls}} tool {{end}}`, +// want: " tool ", +// ok: true, +// }, +// } - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tt.template) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } +// for _, tt := range cases { +// t.Run(tt.name, func(t *testing.T) { +// tmpl, err := gotmpl.New("test").Parse(tt.template) +// if err != nil { +// t.Fatalf("failed to parse template: %v", err) +// } - got, ok := extractToolCallsTemplate(tmpl) - if got != tt.want { - t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) - } - if ok != tt.ok { - t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) - } - }) - } -} +// got, ok := extractToolCallsTemplate(tmpl) +// if got != tt.want { +// t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) +// } +// if ok != tt.ok { +// t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) +// } +// }) +// } +// } diff --git a/server/routes.go b/server/routes.go index 678af53f6..64823bd32 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1483,19 +1483,21 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + slog.Debug("chat request", "images", len(images), "prompt", prompt) + + var toolParser *tools.Parser + if len(req.Tools) > 0 { + toolParser, err = tools.NewParser(m.Template.Template) + if err != nil { + slog.Error("failed to create tool parser", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + ch := make(chan any) go func() { defer close(ch) - // ! personally not a fan of this pattern - toolTemplate, ok := ToolTemplate(m) - if !ok { - slog.Error("tool template not found", "model", m.Name) - return - } - var toolParser *tools.Parser - if len(req.Tools) > 0 { - toolParser = tools.NewParser(m.Template.Template, toolTemplate) - } if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, @@ -1523,30 +1525,20 @@ func (s *Server) ChatHandler(c *gin.Context) { } if len(req.Tools) > 0 && !toolParser.Done { - toolCalls, leftover := toolParser.ParseToolCalls(r.Content) - // * This can be abstracted again to a .handleState(tp.state) - // * However, we'd need a flag to indicate whether to send the response or not - // * happy to take whatever is more idiomatic - switch toolParser.ParserState { - case tools.ToolCallAccumulate: - // tokens are accumulated in the tool parser - return - case tools.ToolCallSendTokens: - // tokens are sent back in the response - case tools.ToolCallSendPartial: - // tokens not needed for parsing are sent back in the response - if len(leftover) > 0 { - res.Message.Content = leftover + toolCalls, content, err := toolParser.Add(r.Content) + if err == nil { + if len(content) > 0 { + res.Message.Content = content + fmt.Println("sending content in response", content) + } else if len(toolCalls) > 0 { + fmt.Println("sending tool calls in response", toolCalls) + res.Message.ToolCalls = toolCalls + res.Message.Content = "" + } else { + return } - // ! state is needed as we need to not match on the other states - case tools.ToolCallFound: - res.Message.ToolCalls = toolCalls - res.Message.Content = "" } } - - fmt.Println("sending response", res.Message.Content) - // * this is where we'd need the flag if we have a .handleState(tp.state) ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} diff --git a/tools/tools.go b/tools/tools.go index 0e2f9dabf..1105848e9 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -6,91 +6,28 @@ import ( "fmt" "io" "log/slog" - "slices" "strings" gotmpl "text/template" - "text/template/parse" jsonv2 "github.com/go-json-experiment/json" jsontext "github.com/go-json-experiment/json/jsontext" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" ) -type State int - -// TODO: potentially coalesce states -const ( - SendTokens State = iota - GreedyToolWithPrefix - GreedyToolNoPrefix - ForceTools - ToolSuffix - ContainsPrefix - PartialPrefix - NotPartialPrefix - Done -) - -type ExternalState int - -const ( - ToolCallFound ExternalState = iota - ToolCallSendPartial - ToolCallAccumulate - ToolCallSendTokens -) - -func (s ExternalState) String() string { - switch s { - case ToolCallFound: - return "ToolCallFound" - case ToolCallSendPartial: - return "ToolCallSendPartial" - case ToolCallAccumulate: - return "ToolCallAccumulate" - case ToolCallSendTokens: - return "ToolCallSendTokens" - default: - return fmt.Sprintf("Unknown ExternalState (%d)", s) - } -} - -func (s State) String() string { - switch s { - case SendTokens: - return "SendTokens" - case GreedyToolWithPrefix: - return "GreedyToolWithPrefix" - case GreedyToolNoPrefix: - return "GreedyToolNoPrefix" - case ForceTools: - return "ForceTools" - case ToolSuffix: - return "ToolSuffix" - case PartialPrefix: - return "PossiblePrefix" - case Done: - return "Done" - case ContainsPrefix: - return "PartialPrefix" - default: - return fmt.Sprintf("Unknown State (%d)", s) - } -} - // TODO: simplify if possible type Parser struct { - tmpl *gotmpl.Template - state State - sb *strings.Builder - toolPrefix string - toolIndex int - ParserState ExternalState - Done bool + greedy bool + prefixFound bool + partialPrefix bool + tmpl *gotmpl.Template + sb *strings.Builder + prefix string + index int + Done bool } -// ? move to a separate file // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { @@ -222,314 +159,126 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { return toolCalls, false, true } -// TODO: clean up the boundary of internal and external state transitions -func (p *Parser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) { - fmt.Printf("updating output state: ok=%v partial=%v tool_calls=%d current_state=%s\n", ok, partial, len(tcs), p.state) +// prefix stripped string if any, prefix found, and if we should accumulate +func (p *Parser) checkPrefix(s string) (string, bool, bool) { - // state transition logic - switch { - case !ok && !partial && p.state == ForceTools: - // force partial tool if we have a prefix - // no op and stay in force tools - p.sb.Reset() - case !ok && !partial: - if p.state == GreedyToolNoPrefix { - p.state = Done - // ? the output parser state is the same even though internal can we not leak the external state? - p.Done = true - } - if p.state == GreedyToolWithPrefix { - p.state = SendTokens - } - if p.state == PartialPrefix { - p.state = NotPartialPrefix - } - case !ok && partial: - // acucumulate - - case len(tcs) > 0: - // do not parse again in the greedy JSON case as soon as we have a tool call - p.sb.Reset() - } - p.updateExternalState(tcs) - fmt.Printf("state updated: new_state=%s parser_state=%s\n", p.state, p.ParserState) -} - -func (p *Parser) updateExternalState(tcs []api.ToolCall) { - fmt.Printf("updating external state: current_state=%s tool_calls=%d\n", p.state, len(tcs)) - - switch { - case len(tcs) > 0: - // do not parse again in the greedy JSON case as soon as we have a tool call - if p.state == GreedyToolWithPrefix { - p.state = SendTokens - } else if p.state == GreedyToolNoPrefix { - p.state = Done - p.Done = true - } - p.ParserState = ToolCallFound - case p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix || - p.state == ToolSuffix || p.state == PartialPrefix || - (p.state == ForceTools && len(tcs) == 0): - p.ParserState = ToolCallAccumulate - case p.state == ContainsPrefix: - p.ParserState = ToolCallSendPartial - case p.state == SendTokens || p.state == Done: - p.ParserState = ToolCallSendTokens - case p.state == NotPartialPrefix: - p.ParserState = ToolCallSendPartial - default: - p.ParserState = ToolCallSendTokens - p.sb.Reset() - p.state = SendTokens - } -} - -// string, and if it has a prefix -func (p *Parser) checkPrefix(s string) (string, bool) { - fmt.Printf("checking prefix: input=%s prefix=%s\n", s, p.toolPrefix) - - if p.toolPrefix == "" { - return s, true + if p.prefix == "" { + return s, false, true } original := s - s, hasPrefix := strings.CutPrefix(s, p.toolPrefix) + s = strings.TrimSpace(s) + s, hasPrefix := strings.CutPrefix(s, p.prefix) if hasPrefix { - p.state = ForceTools - fmt.Printf("found exact prefix match: remaining=%s\n", s) // partial tool possibly - accumulate - } else if suffixOverlap(s, p.toolPrefix) > 0 { - p.state = PartialPrefix - fmt.Printf("found partial prefix: remaining=%s\n", s) - return "", false - // the case where "token" - send "token" back + return s, true, true + } else if overlap := suffixOverlap(original, p.prefix); overlap > 0 { + // p.state = PartialPrefix + p.partialPrefix = true + return original[0 : len(original)-overlap], false, false + } else if idx := strings.Index(original, p.prefix); idx != -1 { + // Found prefix in middle of string, keep only content before prefix // accounts for spaces in prefix or suffix to avoid breaking cache - } else if strings.Contains(original, p.toolPrefix) { - idx := strings.Index(original, p.toolPrefix) - if idx != -1 { - // still keeps the prefix - p.state = ContainsPrefix - p.sb.Reset() - // todo: see if there is a simpler way for this - idx2 := strings.Index(s, p.toolPrefix) - // buffer now only has the prefix - p.sb.WriteString(s[idx2:]) - fmt.Printf("found prefix in middle: prefix_start=%d content_before=%s\n", idx, original[:idx]) - return original[:idx], false - } + p.partialPrefix = true + p.sb.Reset() + + p.sb.WriteString(strings.TrimSpace(original[idx:])) + return original[:idx], false, false } - return s, true + p.partialPrefix = false + return s, false, true } -// TODO: simplify the flow of this function -// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. -// Returns tool calls, whether parsing is incomplete, and any errors. -func (p *Parser) ParseToolCalls(s string) ([]api.ToolCall, string) { - fmt.Printf("parsing tool calls: input=%s current_state=%s\n", s, p.state) +func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) { + slog.Debug("adding tool calls", "input", s) p.sb.WriteString(s) s = p.sb.String() - s = strings.TrimSpace(s) - if len(s) == 0 { - p.updateExternalState(nil) - return nil, "" + return nil, "", nil } - s, cont := p.checkPrefix(s) + s, prefixFound, cont := p.checkPrefix(s) + if !cont { - p.updateExternalState(nil) - if p.state == ContainsPrefix { - fmt.Printf("returning partial prefix: remaining=%s\n", s) - return nil, s + if s != "" { + // send only the content back, prefix exists + return nil, s, nil } - // * we'd be returning here for just accumulating with possible prefix - // * ext state is accumulation - return nil, "" + // accumulate case + return nil, "", nil } - // * lets say the check fails here and now we're still in external state accumulation here - // stay in SendTokens unless we have a prefix - if p.state == SendTokens { - p.updateExternalState(nil) + // circuit breaker + if prefixFound { + p.prefixFound = true + } + + // for cases with a prefix in template + if p.prefix != "" && !p.greedy && !p.prefixFound { + // send tokens down p.sb.Reset() - fmt.Printf("returning send tokens: remaining=%s\n", s) - return nil, s + return nil, "", errors.New("prefix not found") } - - // * we'd parse here as json to see if it's a tool call + // we have a prefix or are in json mode tcs, partial, ok := p.parseJSONToolCalls(s) - // * it would not be a tool call here - p.updateStateAfterJSONParse(ok, partial, tcs) - if !ok { - // * and so we should send the data here - // * we also need to move out of that internal state after sending the tokens - if p.state == NotPartialPrefix { - p.state = SendTokens - // the string would have acc until here - return nil, p.sb.String() - } - return nil, "" + if partial { + // accumulate case + return nil, "", nil } + + p.greedy = false + if !ok { + // will not be a partial at this point + p.sb.Reset() + // send tokens + if p.prefix == "" { + p.Done = true + } + if p.prefixFound { + // drop tokens instead - sb is reset, no tokens sent to user + return nil, "", nil + } + return nil, "", errors.New("failed to parse tool calls") + } + for _, tc := range tcs { - tc.Function.Index = p.toolIndex - p.toolIndex++ + tc.Function.Index = p.index + p.index++ } - fmt.Printf("finished parsing tool calls: tool_calls_found=%d\n", len(tcs)) - return tcs, "" + if p.prefix == "" { + p.Done = true + } + p.sb.Reset() + return tcs, "", nil } -func suffixOverlap(s, delim string) int { - max := min(len(delim), len(s)) - for i := max; i > 0; i-- { - if strings.HasSuffix(s, delim[:i]) { - return i - } +func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { + parsedTemplate, err := template.Parse(templateToProcess.Root.String()) + if err != nil { + return nil, err } - return 0 -} - -// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls" -func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) { - if tmpl == nil || tmpl.Tree == nil { - slog.Debug("TextAfterToolCalls: template or tree is nil") - return "", false + if parsedTemplate == nil { + return nil, errors.New("failed to parse template") } - var result string - var found bool - - var walk func(nodes []parse.Node) - walk = func(nodes []parse.Node) { - for _, node := range nodes { - if found { - return - } - - switch n := node.(type) { - case *parse.IfNode: - if nodeContainsToolCalls(n) { - // Collect immediate TextNode(s) at start of IfNode's list - var sb strings.Builder - for _, innerNode := range n.List.Nodes { - if tn, ok := innerNode.(*parse.TextNode); ok { - sb.Write(tn.Text) - } else { - // Stop at first non-text node - break - } - } - result = sb.String() - found = true - return - } - // Recurse into child nodes - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - case *parse.ListNode: - walk(n.Nodes) - case *parse.RangeNode: - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - case *parse.WithNode: - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - default: - // Continue to next node - continue - } - - if found { - return - } - } + toolCallTemplate, hasToolCalls := toolTemplate(parsedTemplate) + if !hasToolCalls { + return nil, errors.New("failed to find tool template") + } + if toolCallTemplate == nil { + return nil, errors.New("failed to find tool template") } - walk(tmpl.Tree.Root.Nodes) - return result, found -} + toolPrefix, _ := ToolPrefix(templateToProcess) + toolPrefix = strings.TrimSpace(toolPrefix) -// Helper to detect if a node's condition includes ".ToolCalls" -func nodeContainsToolCalls(n *parse.IfNode) bool { - for _, cmd := range n.Pipe.Cmds { - for _, arg := range cmd.Args { - if field, ok := arg.(*parse.FieldNode); ok { - if slices.Contains(field.Ident, "ToolCalls") { - return true - } - } - } - } - return false -} - -func ToolPrefix(tmpl *gotmpl.Template) (string, bool) { - tokenText, ok := extractToolCallsTemplate(tmpl) - if !ok { - return "", false - } - tokenText = strings.TrimSpace(tokenText) - if tokenText == "" { - return "", false - } - first := strings.Fields(tokenText)[0] - - start := -1 - end := -1 - for i, r := range tokenText { - if r == '<' || r == '[' { - start = i - } - if (r == '>' || r == ']') && start != -1 { - end = i - break - } - } - if start != -1 && end != -1 { - // return the token including the [ or < and the ] or > - return tokenText[start : end+1], true - } else if start != -1 { - // get until the [ or < - in the case tag was not closed - return tokenText[:start], true - } else if end != -1 { - // get after the ] or > - in the case tag was not opened - return tokenText[end+1:], true - } - return first, true -} - -func NewParser(tmpl *gotmpl.Template, toolTemplate *gotmpl.Template) *Parser { - // TODO: use new template parsing to get all tokens for the prefix - if tmpl == nil { - return nil - } - if toolTemplate == nil { - return nil - } - - prefix, _ := ToolPrefix(tmpl) - prefix = strings.TrimSpace(prefix) - - var state State - if prefix == "" { - state = GreedyToolNoPrefix - } else { - state = GreedyToolWithPrefix - } - fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", prefix, state) + fmt.Printf("creating new tool parser: prefix=%s\n", toolPrefix) return &Parser{ - tmpl: toolTemplate, - sb: &strings.Builder{}, - toolPrefix: prefix, - state: state, - ParserState: ToolCallAccumulate, - } + tmpl: toolCallTemplate, + sb: &strings.Builder{}, + prefix: toolPrefix, + greedy: true, + }, nil } diff --git a/tools/tools_test.go b/tools/tools_test.go index 71ed88755..597d0bdd1 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -239,14 +239,14 @@ func TestParseToolCalls(t *testing.T) { model: "qwen3", output: `< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, expectedToolCall: []api.ToolCall{}, - expectedTokens: `< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedTokens: ` fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, }, { name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", model: "qwen3", output: ``, expectedToolCall: []api.ToolCall{}, - expectedTokens: ``, + expectedTokens: ` fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, }, { name: "qwen3 invalid tool call with malformed tool prefix", @@ -315,34 +315,31 @@ func TestParseToolCalls(t *testing.T) { t.Run("parse", func(t *testing.T) { // fmt.Printf("tmpl: %s\n", tmpl.Root.String()) - toolTemplate, ok := toolTemplateHelper(t, tmpl) - if !ok { - t.Fatalf("tool template not found for model %s", tt.model) + tp, err := NewParser(tmpl.Template) + if err != nil { + t.Fatal(err) } - tp := NewParser(tmpl.Template, toolTemplate) got := []api.ToolCall{} var gotTokens strings.Builder + var add bool tokens := strings.Fields(tt.output) for _, tok := range tokens { - add := true s := " " + tok + add = true if !tp.Done { - toolCalls, leftover := tp.ParseToolCalls(s) - switch tp.ParserState { - case ToolCallFound: - got = append(got, toolCalls...) - add = false - case ToolCallSendTokens: - gotTokens.WriteString(s) - add = false - case ToolCallAccumulate: - add = false - case ToolCallSendPartial: - t.Log("send partial", "leftover", leftover) - gotTokens.WriteString(" " + leftover) - add = false + toolCalls, content, err := tp.Add(s) + if err == nil { + if content != "" { + gotTokens.WriteString(content) + add = false + } else if len(toolCalls) > 0 { + got = append(got, toolCalls...) + add = false + } else { + add = false + } } } if add { diff --git a/tools/utils.go b/tools/utils.go new file mode 100644 index 000000000..6e37b9010 --- /dev/null +++ b/tools/utils.go @@ -0,0 +1,155 @@ +package tools + +import ( + "log/slog" + "slices" + "strings" + gotmpl "text/template" + "text/template/parse" + + "github.com/ollama/ollama/template" +) + +// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls" +func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) { + if tmpl == nil || tmpl.Tree == nil { + slog.Debug("TextAfterToolCalls: template or tree is nil") + return "", false + } + + var result string + var found bool + + var walk func(nodes []parse.Node) + walk = func(nodes []parse.Node) { + for _, node := range nodes { + if found { + return + } + + switch n := node.(type) { + case *parse.IfNode: + if nodeContainsToolCalls(n) { + // Collect immediate TextNode(s) at start of IfNode's list + var sb strings.Builder + for _, innerNode := range n.List.Nodes { + if tn, ok := innerNode.(*parse.TextNode); ok { + sb.Write(tn.Text) + } else { + // Stop at first non-text node + break + } + } + result = sb.String() + found = true + return + } + // Recurse into child nodes + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.ListNode: + walk(n.Nodes) + case *parse.RangeNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.WithNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + default: + // Continue to next node + continue + } + + if found { + return + } + } + } + + walk(tmpl.Tree.Root.Nodes) + return result, found +} + +// Helper to detect if a node's condition includes ".ToolCalls" +func nodeContainsToolCalls(n *parse.IfNode) bool { + for _, cmd := range n.Pipe.Cmds { + for _, arg := range cmd.Args { + if field, ok := arg.(*parse.FieldNode); ok { + if slices.Contains(field.Ident, "ToolCalls") { + return true + } + } + } + } + return false +} + +// ToolPrefix returns the prefix for the tool call if it exists +// TODO(parthsareen): get full prefix from the template instead of just the first token +func ToolPrefix(tmpl *gotmpl.Template) (string, bool) { + tokenText, ok := extractToolCallsTemplate(tmpl) + if !ok { + return "", false + } + tokenText = strings.TrimSpace(tokenText) + if tokenText == "" { + return "", false + } + first := strings.Fields(tokenText)[0] + + start := -1 + end := -1 + for i, r := range tokenText { + if r == '<' || r == '[' { + start = i + } + if (r == '>' || r == ']') && start != -1 { + end = i + break + } + } + if start != -1 && end != -1 { + // return the token including the [ or < and the ] or > + return tokenText[start : end+1], true + } else if start != -1 { + // get until the [ or < - in the case tag was not closed + return tokenText[:start], true + } else if end != -1 { + // get after the ] or > - in the case tag was not opened + return tokenText[end+1:], true + } + return first, true +} + +func toolTemplate(t *template.Template) (*gotmpl.Template, bool) { + // create a subtree from the node that ranges over .ToolCalls + tmpl := t.Subtree(func(n parse.Node) bool { + if t, ok := n.(*parse.RangeNode); ok { + return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") + } + + return false + }) + + if tmpl == nil { + return nil, false + } + + return tmpl, true +} + +func suffixOverlap(s, delim string) int { + max := min(len(delim), len(s)) + for i := max; i > 0; i-- { + if strings.HasSuffix(s, delim[:i]) { + return i + } + } + return 0 +} diff --git a/tools/utils_test.go b/tools/utils_test.go new file mode 100644 index 000000000..4c37ecd40 --- /dev/null +++ b/tools/utils_test.go @@ -0,0 +1,185 @@ +package tools + +import ( + "testing" + gotmpl "text/template" +) + +func TestToolPrefix(t *testing.T) { + cases := []struct { + name string + template string + want string + ok bool + }{ + { + name: "basic tool call with action prefix", + template: "{{if .ToolCalls}}Action: ```json{{end}}", + want: "Action:", + ok: true, + }, + { + name: "incomplete functools bracket", + template: "{{if .ToolCalls}}functools[{{end}}", + want: "functools", + ok: true, + }, + { + name: "tool call with angle brackets", + template: "{{if .ToolCalls}}Hello, world! {{end}}", + want: "", + ok: true, + }, + { + name: "multiple tool call formats", + template: "{{if .ToolCalls}}[tool_call] {{end}}", + want: "[tool_call]", + ok: true, + }, + { + name: "single angle bracket tool call", + template: "{{if .ToolCalls}}{{end}}", + want: "", + ok: true, + }, + { + name: "incomplete angle bracket after tool call", + template: "{{if .ToolCalls}}[tool_call] <{{end}}", + want: "[tool_call]", + ok: true, + }, + { + name: "angle bracket prefix with tool call", + template: "{{if .ToolCalls}}> {{end}}", + want: "", + ok: true, + }, + { + name: "uppercase tool call with incomplete bracket", + template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", + want: "[TOOL_CALL]", + ok: true, + }, + { + name: "uppercase tool call with adjacent bracket", + template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", + want: "[TOOL_CALL]", + ok: true, + }, + { + name: "tool call with pipe delimiters", + template: "{{if .ToolCalls}}<|tool_call|>{{end}}", + want: "<|tool_call|>", + ok: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + got, ok := ToolPrefix(tmpl) + if got != tt.want { + t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) + } + if ok != tt.ok { + t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok) + } + }) + } +} + +func TestTextAfterToolCalls(t *testing.T) { + cases := []struct { + name string + template string + want string + ok bool + }{ + { + name: "basic tool call with text after", + template: `{{if .ToolCalls}}tool response{{end}}`, + want: "tool response", + ok: true, + }, + { + name: "tool call with mixed content after", + template: `{{if .ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool call with no text after", + template: `{{if .ToolCalls}}{{.Something}}{{end}}`, + want: "", + ok: true, + }, + { + name: "nested tool call", + template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, + want: "[TOOL_CALL]", + ok: true, + }, + { + name: "no tool calls", + template: `{{if .Something}}no tools here{{end}}`, + want: "", + ok: false, + }, + { + name: "empty template", + template: ``, + want: "", + ok: false, + }, + { + name: "multiple tool calls sections", + template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, + want: "first", + ok: true, + }, + { + name: "range over tool calls", + template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with pipe delimiters", + template: `{{if .ToolCalls}}<|tool|>{{end}}`, + want: "<|tool|>", + ok: true, + }, + { + name: "tool calls with nested template", + template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, + want: "", + ok: true, + }, + { + name: "tool calls with whitespace variations", + template: `{{if .ToolCalls}} tool {{end}}`, + want: " tool ", + ok: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := gotmpl.New("test").Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + got, ok := extractToolCallsTemplate(tmpl) + if got != tt.want { + t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) + } + if ok != tt.ok { + t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) + } + }) + } +}