From 4059b8db015c8ac6fb1b03984c0138a6c11a2b49 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Mon, 12 May 2025 14:07:59 -0700 Subject: [PATCH] renaming and splitting stuff up --- server/model.go | 125 -------------- server/routes.go | 19 ++- .../testdata}/command-r-plus.gotmpl | 0 .../testdata}/command-r-plus.out | 0 .../testdata}/firefunction.gotmpl | 0 .../tools => tools/testdata}/firefunction.out | 0 .../testdata}/llama3-groq-tool-use.gotmpl | 0 .../testdata}/llama3-groq-tool-use.out | 0 .../tools => tools/testdata}/llama3.2.gotmpl | 0 .../tools => tools/testdata}/llama3.2.out | 0 .../tools => tools/testdata}/messages.json | 0 .../tools => tools/testdata}/mistral.gotmpl | 0 .../tools => tools/testdata}/mistral.out | 0 .../tools => tools/testdata}/nemotron.gotmpl | 0 .../tools => tools/testdata}/nemotron.out | 0 .../testdata}/qwen2.5-coder.gotmpl | 0 .../testdata}/qwen2.5-coder.out | 0 .../tools => tools/testdata}/qwen3.gotmpl | 0 .../tools => tools/testdata}/qwen3.out | 0 .../tools => tools/testdata}/tools.json | 0 .../tools => tools/testdata}/xlam.gotmpl | 0 .../tools => tools/testdata}/xlam.out | 0 {server => tools}/tools.go | 155 ++++++++++++++++-- {server => tools}/tools_test.go | 37 ++++- 24 files changed, 184 insertions(+), 152 deletions(-) rename {server/testdata/tools => tools/testdata}/command-r-plus.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/command-r-plus.out (100%) rename {server/testdata/tools => tools/testdata}/firefunction.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/firefunction.out (100%) rename {server/testdata/tools => tools/testdata}/llama3-groq-tool-use.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/llama3-groq-tool-use.out (100%) rename {server/testdata/tools => tools/testdata}/llama3.2.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/llama3.2.out (100%) rename {server/testdata/tools => tools/testdata}/messages.json (100%) rename {server/testdata/tools => tools/testdata}/mistral.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/mistral.out (100%) rename {server/testdata/tools => tools/testdata}/nemotron.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/nemotron.out (100%) rename {server/testdata/tools => tools/testdata}/qwen2.5-coder.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/qwen2.5-coder.out (100%) rename {server/testdata/tools => tools/testdata}/qwen3.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/qwen3.out (100%) rename {server/testdata/tools => tools/testdata}/tools.json (100%) rename {server/testdata/tools => tools/testdata}/xlam.gotmpl (100%) rename {server/testdata/tools => tools/testdata}/xlam.out (100%) rename {server => tools}/tools.go (75%) rename {server => tools}/tools_test.go (95%) diff --git a/server/model.go b/server/model.go index eb28d3733..7e749829c 100644 --- a/server/model.go +++ b/server/model.go @@ -11,7 +11,6 @@ import ( "net/http" "os" "slices" - "strings" gotmpl "text/template" "text/template/parse" @@ -130,130 +129,6 @@ func detectContentType(r io.Reader) (string, error) { return "unknown", nil } -// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls" -func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) { - if tmpl == nil || tmpl.Tree == nil { - slog.Debug("TextAfterToolCalls: template or tree is nil") - return "", false - } - - var result string - var found bool - - var walk func(nodes []parse.Node) - walk = func(nodes []parse.Node) { - for _, node := range nodes { - if found { - return - } - - switch n := node.(type) { - case *parse.IfNode: - if nodeContainsToolCalls(n) { - // Collect immediate TextNode(s) at start of IfNode's list - var sb strings.Builder - for _, innerNode := range n.List.Nodes { - if tn, ok := innerNode.(*parse.TextNode); ok { - sb.Write(tn.Text) - } else { - // Stop at first non-text node - break - } - } - result = sb.String() - found = true - return - } - // Recurse into child nodes - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - case *parse.ListNode: - walk(n.Nodes) - case *parse.RangeNode: - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - case *parse.WithNode: - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - default: - // Continue to next node - continue - } - - if found { - return - } - } - } - - walk(tmpl.Tree.Root.Nodes) - return result, found -} - -// Helper to detect if a node's condition includes ".ToolCalls" -func nodeContainsToolCalls(n *parse.IfNode) bool { - for _, cmd := range n.Pipe.Cmds { - for _, arg := range cmd.Args { - if field, ok := arg.(*parse.FieldNode); ok { - if slices.Contains(field.Ident, "ToolCalls") { - return true - } - } - } - } - return false -} - -func ToolPrefix2(tmpl *gotmpl.Template) (string, bool) { - tokenText, ok := extractToolCallsTemplate(tmpl) - if !ok { - return "", false - } - tokenText = strings.TrimSpace(tokenText) - return tokenText, true -} - -func ToolPrefix(tmpl *gotmpl.Template) (string, bool) { - tokenText, ok := extractToolCallsTemplate(tmpl) - if !ok { - return "", false - } - tokenText = strings.TrimSpace(tokenText) - if tokenText == "" { - return "", false - } - first := strings.Fields(tokenText)[0] - - start := -1 - end := -1 - for i, r := range tokenText { - if r == '<' || r == '[' { - start = i - } - if (r == '>' || r == ']') && start != -1 { - end = i - break - } - } - if start != -1 && end != -1 { - // return the token including the [ or < and the ] or > - return tokenText[start : end+1], true - } else if start != -1 { - // get until the [ or < - in the case tag was not closed - return tokenText[:start], true - } else if end != -1 { - // get after the ] or > - in the case tag was not opened - return tokenText[end+1:], true - } - return first, true -} - func ToolTemplate(m *Model) (*gotmpl.Template, bool) { // create a subtree from the node that ranges over .ToolCalls tmpl := m.Template.Subtree(func(n parse.Node) bool { diff --git a/server/routes.go b/server/routes.go index daf8017f7..678af53f6 100644 --- a/server/routes.go +++ b/server/routes.go @@ -38,6 +38,7 @@ import ( "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" + "github.com/ollama/ollama/tools" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -1485,9 +1486,15 @@ func (s *Server) ChatHandler(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) - var toolParser *ToolParser + // ! personally not a fan of this pattern + toolTemplate, ok := ToolTemplate(m) + if !ok { + slog.Error("tool template not found", "model", m.Name) + return + } + var toolParser *tools.Parser if len(req.Tools) > 0 { - toolParser = NewToolParser(m) + toolParser = tools.NewParser(m.Template.Template, toolTemplate) } if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ @@ -1521,18 +1528,18 @@ func (s *Server) ChatHandler(c *gin.Context) { // * However, we'd need a flag to indicate whether to send the response or not // * happy to take whatever is more idiomatic switch toolParser.ParserState { - case ToolCallAccumulate: + case tools.ToolCallAccumulate: // tokens are accumulated in the tool parser return - case ToolCallSendTokens: + case tools.ToolCallSendTokens: // tokens are sent back in the response - case ToolCallSendPartial: + case tools.ToolCallSendPartial: // tokens not needed for parsing are sent back in the response if len(leftover) > 0 { res.Message.Content = leftover } // ! state is needed as we need to not match on the other states - case ToolCallFound: + case tools.ToolCallFound: res.Message.ToolCalls = toolCalls res.Message.Content = "" } diff --git a/server/testdata/tools/command-r-plus.gotmpl b/tools/testdata/command-r-plus.gotmpl similarity index 100% rename from server/testdata/tools/command-r-plus.gotmpl rename to tools/testdata/command-r-plus.gotmpl diff --git a/server/testdata/tools/command-r-plus.out b/tools/testdata/command-r-plus.out similarity index 100% rename from server/testdata/tools/command-r-plus.out rename to tools/testdata/command-r-plus.out diff --git a/server/testdata/tools/firefunction.gotmpl b/tools/testdata/firefunction.gotmpl similarity index 100% rename from server/testdata/tools/firefunction.gotmpl rename to tools/testdata/firefunction.gotmpl diff --git a/server/testdata/tools/firefunction.out b/tools/testdata/firefunction.out similarity index 100% rename from server/testdata/tools/firefunction.out rename to tools/testdata/firefunction.out diff --git a/server/testdata/tools/llama3-groq-tool-use.gotmpl b/tools/testdata/llama3-groq-tool-use.gotmpl similarity index 100% rename from server/testdata/tools/llama3-groq-tool-use.gotmpl rename to tools/testdata/llama3-groq-tool-use.gotmpl diff --git a/server/testdata/tools/llama3-groq-tool-use.out b/tools/testdata/llama3-groq-tool-use.out similarity index 100% rename from server/testdata/tools/llama3-groq-tool-use.out rename to tools/testdata/llama3-groq-tool-use.out diff --git a/server/testdata/tools/llama3.2.gotmpl b/tools/testdata/llama3.2.gotmpl similarity index 100% rename from server/testdata/tools/llama3.2.gotmpl rename to tools/testdata/llama3.2.gotmpl diff --git a/server/testdata/tools/llama3.2.out b/tools/testdata/llama3.2.out similarity index 100% rename from server/testdata/tools/llama3.2.out rename to tools/testdata/llama3.2.out diff --git a/server/testdata/tools/messages.json b/tools/testdata/messages.json similarity index 100% rename from server/testdata/tools/messages.json rename to tools/testdata/messages.json diff --git a/server/testdata/tools/mistral.gotmpl b/tools/testdata/mistral.gotmpl similarity index 100% rename from server/testdata/tools/mistral.gotmpl rename to tools/testdata/mistral.gotmpl diff --git a/server/testdata/tools/mistral.out b/tools/testdata/mistral.out similarity index 100% rename from server/testdata/tools/mistral.out rename to tools/testdata/mistral.out diff --git a/server/testdata/tools/nemotron.gotmpl b/tools/testdata/nemotron.gotmpl similarity index 100% rename from server/testdata/tools/nemotron.gotmpl rename to tools/testdata/nemotron.gotmpl diff --git a/server/testdata/tools/nemotron.out b/tools/testdata/nemotron.out similarity index 100% rename from server/testdata/tools/nemotron.out rename to tools/testdata/nemotron.out diff --git a/server/testdata/tools/qwen2.5-coder.gotmpl b/tools/testdata/qwen2.5-coder.gotmpl similarity index 100% rename from server/testdata/tools/qwen2.5-coder.gotmpl rename to tools/testdata/qwen2.5-coder.gotmpl diff --git a/server/testdata/tools/qwen2.5-coder.out b/tools/testdata/qwen2.5-coder.out similarity index 100% rename from server/testdata/tools/qwen2.5-coder.out rename to tools/testdata/qwen2.5-coder.out diff --git a/server/testdata/tools/qwen3.gotmpl b/tools/testdata/qwen3.gotmpl similarity index 100% rename from server/testdata/tools/qwen3.gotmpl rename to tools/testdata/qwen3.gotmpl diff --git a/server/testdata/tools/qwen3.out b/tools/testdata/qwen3.out similarity index 100% rename from server/testdata/tools/qwen3.out rename to tools/testdata/qwen3.out diff --git a/server/testdata/tools/tools.json b/tools/testdata/tools.json similarity index 100% rename from server/testdata/tools/tools.json rename to tools/testdata/tools.json diff --git a/server/testdata/tools/xlam.gotmpl b/tools/testdata/xlam.gotmpl similarity index 100% rename from server/testdata/tools/xlam.gotmpl rename to tools/testdata/xlam.gotmpl diff --git a/server/testdata/tools/xlam.out b/tools/testdata/xlam.out similarity index 100% rename from server/testdata/tools/xlam.out rename to tools/testdata/xlam.out diff --git a/server/tools.go b/tools/tools.go similarity index 75% rename from server/tools.go rename to tools/tools.go index 3bb870b55..0e2f9dabf 100644 --- a/server/tools.go +++ b/tools/tools.go @@ -1,12 +1,15 @@ -package server +package tools import ( "bytes" "errors" "fmt" "io" + "log/slog" + "slices" "strings" gotmpl "text/template" + "text/template/parse" jsonv2 "github.com/go-json-experiment/json" jsontext "github.com/go-json-experiment/json/jsontext" @@ -77,7 +80,7 @@ func (s State) String() string { } // TODO: simplify if possible -type ToolParser struct { +type Parser struct { tmpl *gotmpl.Template state State sb *strings.Builder @@ -90,7 +93,7 @@ type ToolParser struct { // ? move to a separate file // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found -func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { +func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s) var b bytes.Buffer @@ -220,7 +223,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { } // TODO: clean up the boundary of internal and external state transitions -func (p *ToolParser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) { +func (p *Parser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) { fmt.Printf("updating output state: ok=%v partial=%v tool_calls=%d current_state=%s\n", ok, partial, len(tcs), p.state) // state transition logic @@ -252,7 +255,7 @@ func (p *ToolParser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api. fmt.Printf("state updated: new_state=%s parser_state=%s\n", p.state, p.ParserState) } -func (p *ToolParser) updateExternalState(tcs []api.ToolCall) { +func (p *Parser) updateExternalState(tcs []api.ToolCall) { fmt.Printf("updating external state: current_state=%s tool_calls=%d\n", p.state, len(tcs)) switch { @@ -283,7 +286,7 @@ func (p *ToolParser) updateExternalState(tcs []api.ToolCall) { } // string, and if it has a prefix -func (p *ToolParser) checkPrefix(s string) (string, bool) { +func (p *Parser) checkPrefix(s string) (string, bool) { fmt.Printf("checking prefix: input=%s prefix=%s\n", s, p.toolPrefix) if p.toolPrefix == "" { @@ -322,7 +325,7 @@ func (p *ToolParser) checkPrefix(s string) (string, bool) { // TODO: simplify the flow of this function // ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. // Returns tool calls, whether parsing is incomplete, and any errors. -func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string) { +func (p *Parser) ParseToolCalls(s string) ([]api.ToolCall, string) { fmt.Printf("parsing tool calls: input=%s current_state=%s\n", s, p.state) p.sb.WriteString(s) @@ -388,26 +391,144 @@ func suffixOverlap(s, delim string) int { return 0 } -func NewToolParser(model *Model) *ToolParser { - // TODO: use new template parsing to get all tokens for the prefix - templateToolPrefix, _ := ToolPrefix(model.Template.Template) - templateToolPrefix = strings.TrimSpace(templateToolPrefix) - tmpl, ok := ToolTemplate(model) +// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls" +func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) { + if tmpl == nil || tmpl.Tree == nil { + slog.Debug("TextAfterToolCalls: template or tree is nil") + return "", false + } + + var result string + var found bool + + var walk func(nodes []parse.Node) + walk = func(nodes []parse.Node) { + for _, node := range nodes { + if found { + return + } + + switch n := node.(type) { + case *parse.IfNode: + if nodeContainsToolCalls(n) { + // Collect immediate TextNode(s) at start of IfNode's list + var sb strings.Builder + for _, innerNode := range n.List.Nodes { + if tn, ok := innerNode.(*parse.TextNode); ok { + sb.Write(tn.Text) + } else { + // Stop at first non-text node + break + } + } + result = sb.String() + found = true + return + } + // Recurse into child nodes + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.ListNode: + walk(n.Nodes) + case *parse.RangeNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + case *parse.WithNode: + walk(n.List.Nodes) + if n.ElseList != nil { + walk(n.ElseList.Nodes) + } + default: + // Continue to next node + continue + } + + if found { + return + } + } + } + + walk(tmpl.Tree.Root.Nodes) + return result, found +} + +// Helper to detect if a node's condition includes ".ToolCalls" +func nodeContainsToolCalls(n *parse.IfNode) bool { + for _, cmd := range n.Pipe.Cmds { + for _, arg := range cmd.Args { + if field, ok := arg.(*parse.FieldNode); ok { + if slices.Contains(field.Ident, "ToolCalls") { + return true + } + } + } + } + return false +} + +func ToolPrefix(tmpl *gotmpl.Template) (string, bool) { + tokenText, ok := extractToolCallsTemplate(tmpl) if !ok { + return "", false + } + tokenText = strings.TrimSpace(tokenText) + if tokenText == "" { + return "", false + } + first := strings.Fields(tokenText)[0] + + start := -1 + end := -1 + for i, r := range tokenText { + if r == '<' || r == '[' { + start = i + } + if (r == '>' || r == ']') && start != -1 { + end = i + break + } + } + if start != -1 && end != -1 { + // return the token including the [ or < and the ] or > + return tokenText[start : end+1], true + } else if start != -1 { + // get until the [ or < - in the case tag was not closed + return tokenText[:start], true + } else if end != -1 { + // get after the ] or > - in the case tag was not opened + return tokenText[end+1:], true + } + return first, true +} + +func NewParser(tmpl *gotmpl.Template, toolTemplate *gotmpl.Template) *Parser { + // TODO: use new template parsing to get all tokens for the prefix + if tmpl == nil { + return nil + } + if toolTemplate == nil { return nil } + prefix, _ := ToolPrefix(tmpl) + prefix = strings.TrimSpace(prefix) + var state State - if templateToolPrefix == "" { + if prefix == "" { state = GreedyToolNoPrefix } else { state = GreedyToolWithPrefix } - fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", templateToolPrefix, state) - return &ToolParser{ - tmpl: tmpl, + fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", prefix, state) + return &Parser{ + tmpl: toolTemplate, sb: &strings.Builder{}, - toolPrefix: templateToolPrefix, + toolPrefix: prefix, state: state, ParserState: ToolCallAccumulate, } diff --git a/server/tools_test.go b/tools/tools_test.go similarity index 95% rename from server/tools_test.go rename to tools/tools_test.go index 6ec5712e3..71ed88755 100644 --- a/server/tools_test.go +++ b/tools/tools_test.go @@ -1,4 +1,4 @@ -package server +package tools import ( "bytes" @@ -6,8 +6,11 @@ import ( "fmt" "os" "path/filepath" + "slices" "strings" "testing" + gotmpl "text/template" + "text/template/parse" "github.com/google/go-cmp/cmp" @@ -27,7 +30,7 @@ func readFile(t *testing.T, base, name string) *bytes.Buffer { } func TestParseToolCalls(t *testing.T) { - p := filepath.Join("testdata", "tools") + p := filepath.Join("testdata") t1 := api.ToolCall{ Function: api.ToolCallFunction{ Name: "get_current_weather", @@ -311,8 +314,12 @@ func TestParseToolCalls(t *testing.T) { }) t.Run("parse", func(t *testing.T) { - m := &Model{Template: tmpl} - tp := NewToolParser(m) + // fmt.Printf("tmpl: %s\n", tmpl.Root.String()) + toolTemplate, ok := toolTemplateHelper(t, tmpl) + if !ok { + t.Fatalf("tool template not found for model %s", tt.model) + } + tp := NewParser(tmpl.Template, toolTemplate) got := []api.ToolCall{} var gotTokens strings.Builder @@ -358,3 +365,25 @@ func TestParseToolCalls(t *testing.T) { }) } } + +func toolTemplateHelper(t *testing.T, tmpl *template.Template) (*gotmpl.Template, bool) { + // create a subtree from the node that ranges over .ToolCalls + + tmpl2 := tmpl.Subtree(func(n parse.Node) bool { + if t, ok := n.(*parse.RangeNode); ok { + return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") + } + + return false + }) + + if tmpl2.Root != nil { + t.Log("tmpl2", tmpl2.Root.String()) + } + + if tmpl2 == nil { + return nil, false + } + + return tmpl2, true +}