diff --git a/go.mod b/go.mod
index 283286b7d..d9de611ba 100644
--- a/go.mod
+++ b/go.mod
@@ -35,6 +35,7 @@ require (
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
diff --git a/go.sum b/go.sum
index 5755616f6..780a76f10 100644
--- a/go.sum
+++ b/go.sum
@@ -69,6 +69,8 @@ github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3
github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY=
github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
+github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 h1:+VexzzkMLb1tnvpuQdGT/DicIRW7MN8ozsXqBMgp0Hk=
+github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M=
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
diff --git a/server/model.go b/server/model.go
index b5c91ef1e..eb28d3733 100644
--- a/server/model.go
+++ b/server/model.go
@@ -210,7 +210,16 @@ func nodeContainsToolCalls(n *parse.IfNode) bool {
return false
}
-func ToolToken(tmpl *gotmpl.Template) (string, bool) {
+// ToolPrefix2 returns the text produced by extractToolCallsTemplate for tmpl
+// with surrounding whitespace removed. It returns "" and false when
+// extractToolCallsTemplate reports failure.
+func ToolPrefix2(tmpl *gotmpl.Template) (string, bool) {
+ text, found := extractToolCallsTemplate(tmpl)
+ if !found {
+ return "", false
+ }
+ return strings.TrimSpace(text), true
+}
+
+func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
tokenText, ok := extractToolCallsTemplate(tmpl)
if !ok {
return "", false
diff --git a/server/model_test.go b/server/model_test.go
index 498fdb408..8fd19d2db 100644
--- a/server/model_test.go
+++ b/server/model_test.go
@@ -80,7 +80,7 @@ func TestToolToken(t *testing.T) {
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
- got, ok := ToolToken(tmpl)
+ got, ok := ToolPrefix(tmpl)
if got != tt.want {
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
}
diff --git a/server/routes.go b/server/routes.go
index c40ab211c..950185cc6 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -21,7 +21,6 @@ import (
"slices"
"strings"
"syscall"
- gotmpl "text/template"
"time"
"github.com/gin-contrib/cors"
@@ -1486,26 +1485,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
- var sb strings.Builder
+ // var sb strings.Builder
var toolCallIndex int = 0
- var templateToolToken string
- var tmpl *gotmpl.Template
+ var tp *ToolParser
if len(req.Tools) > 0 {
- var ok bool
- templateToolToken, ok = ToolToken(m.Template.Template)
- if !ok {
- slog.Debug("no tool token found")
- }
- tmpl, ok = ToolTemplate(m)
- if !ok {
- slog.Debug("no tool template found")
- }
+ tp = NewToolParser(m)
}
- checkToolCall := false
- if len(req.Tools) > 0 {
- checkToolCall = true
- }
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
@@ -1526,50 +1512,29 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if r.Done {
- if sb.Len() > 0 {
- res.Message.Content = sb.String()
- }
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
- sb.WriteString(r.Content)
- if len(req.Tools) > 0 && checkToolCall {
- slog.Debug("parse tool calls", "content", sb.String(), "templateToolToken", templateToolToken)
- toolCalls, partial, err := ParseToolCalls(sb.String(), templateToolToken, tmpl)
- if err == nil {
- if partial {
- // circuit break to remove tool end token
- if len(toolCalls) > 0 {
- sb.Reset()
- }
- // If the tool call is partial, we need to wait for the next chunk
- return
- }
+ // tp stays nil when no tools were requested or NewToolParser found no tool
+ // template, so it must be checked before use.
+ if tp != nil && !tp.done {
+ slog.Debug("checking tool calls")
+ toolCalls, ok := tp.ParseToolCalls(r.Content)
+ // Withhold chunks while a tool call is still buffering, but never the
+ // final response: it carries the done reason and durations.
+ if tp.state == PartialTool && !r.Done {
+ slog.Debug("partial tool, returning")
+ return
+ }
+ if ok && len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
+ // Remove content when tool call is present
res.Message.Content = ""
- ch <- res
- // Only way to have multiple calls is to have [] which is derived or provided
- // This case occurs when the tool call is a json block - do not allow tool calls again
- if templateToolToken == "" || (templateToolToken != "" && !strings.HasPrefix(sb.String(), templateToolToken)) {
- checkToolCall = false
- }
- sb.Reset()
- return
}
}
- // If there is no template tool token, we don't need to check for tool calls after the first chunk
- if templateToolToken == "" {
- checkToolCall = false
- }
- res.Message.Content = sb.String()
- sb.Reset()
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
diff --git a/server/tools.go b/server/tools.go
index aa2f729db..25d859948 100644
--- a/server/tools.go
+++ b/server/tools.go
@@ -2,50 +2,39 @@ package server
import (
"bytes"
- "encoding/json"
"errors"
"fmt"
"io"
+ "log/slog"
"strings"
gotmpl "text/template"
+ jsonv2 "github.com/go-json-experiment/json"
+
"github.com/ollama/ollama/api"
)
-func parseObjects(s string) []map[string]any {
- var objs []map[string]any
- for offset := 0; offset < len(s); {
- var obj map[string]any
- decoder := json.NewDecoder(strings.NewReader(s[offset:]))
- err := decoder.Decode(&obj)
- switch {
- case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF):
- return objs
- case err != nil:
- var syntax *json.SyntaxError
- var unmarshalType *json.UnmarshalTypeError
- switch {
- case errors.As(err, &syntax):
- offset += int(syntax.Offset)
- continue
- case errors.As(err, &unmarshalType):
- offset += int(unmarshalType.Offset)
- continue
- default:
- return nil
- }
- }
- offset += int(decoder.InputOffset())
- objs = append(objs, obj)
- }
- return objs
+// State is the ToolParser's classification of the content it has buffered so far.
+type State int
+
+const (
+ // NoTool means the buffered content was not recognized as a tool call.
+ // It is also the zero value of State.
+ NoTool State = iota
+ // PartialTool means a tool prefix was seen or the JSON is still incomplete,
+ // so ParseToolCalls keeps buffering.
+ PartialTool
+ // ToolCall means at least one complete tool call was parsed from the buffer.
+ ToolCall
+)
+
+// ToolParser incrementally extracts tool calls from streamed model output.
+// Content passed to ParseToolCalls is accumulated until it can be classified.
+type ToolParser struct {
+ // tmpl is the tool-call template; parseJSONToolCalls executes it to learn
+ // which JSON keys hold the function name and arguments.
+ tmpl *gotmpl.Template
+ // state is the classification produced by the most recent ParseToolCalls call.
+ state State
+ // sb accumulates content across ParseToolCalls calls; it is reset once the
+ // state becomes NoTool or ToolCall.
+ sb *strings.Builder
+ // toolPrefix is the prefix that introduces a tool call; it may be empty.
+ toolPrefix string
+ // done is set after the first successful parse when toolPrefix is empty.
+ done bool
}
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
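-// Returns parsed tool calls and a boolean indicating if the JSON is incomplete
+// It returns the parsed tool calls, a flag that is true when s is truncated JSON
+// that may still be completed (after a full parse it reports whether any tool call
+// was found), and a flag that is true only when s decoded as complete JSON
+// containing at least one object.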
-func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool) {
+func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
var b bytes.Buffer
- if err := tmpl.Execute(&b, map[string][]api.ToolCall{
+ if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
@@ -57,35 +46,18 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool)
},
},
}); err != nil {
- return nil, false
+ return nil, false, false
}
- templateObjects := parseObjects(b.String())
- if len(templateObjects) == 0 {
- return nil, false
+ // slog.Debug("template", "template", b.String())
+
+ // ! this can be either a map or an array
+ var temp any
+ err := jsonv2.Unmarshal(b.Bytes(), &temp)
+ if err != nil {
+ return nil, false, false
}
- // find the keys that correspond to the name and arguments fields
- var name, arguments string
- for k, v := range templateObjects[0] {
- switch v.(type) {
- case string:
- name = k
- case map[string]any:
- arguments = k
- }
- }
-
- if name == "" || arguments == "" {
- return nil, false
- }
-
- responseObjects := parseObjects(s)
- if len(responseObjects) == 0 {
- return nil, false
- }
-
- // collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
@@ -103,16 +75,63 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool)
return all
}
- var objs []map[string]any
- for _, p := range responseObjects {
- objs = append(objs, collect(p)...)
+ // Unmarshalling into `any` yields map[string]any for a JSON object and
+ // []any for a JSON array, so those are the only shapes that can occur.
+ var templateObjects []map[string]any
+ switch t := temp.(type) {
+ case map[string]any:
+ templateObjects = []map[string]any{t}
+ case []any:
+ templateObjects = collect(t)
+ }
+ if len(templateObjects) == 0 {
+ return nil, false, false
+ }
+ // fmt.Println("template objects", templateObjects)
+
+ // find the keys that correspond to the name and arguments fields
+ var name, arguments string
+ for k, v := range templateObjects[0] {
+ switch v.(type) {
+ case string:
+ name = k
+ case map[string]any:
+ arguments = k
+ }
+ }
+
+ if name == "" || arguments == "" {
+ return nil, false, false
+ }
+ var responseObjects any
+ err = jsonv2.Unmarshal([]byte(s), &responseObjects)
+ if err != nil {
+ // Truncated input means more of the tool call may still arrive; report
+ // it as partial so ParseToolCalls keeps buffering.
+ if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
+ slog.Debug("detected partial or incomplete JSON", "state", p.state)
+ return nil, true, false
+ }
+ // Any other decode error means the content is not tool-call JSON.
+ slog.Debug("content is not valid JSON", "error", err, "state", p.state)
+ return nil, false, false
+ }
+
+ var objs []map[string]any
+ objs = append(objs, collect(responseObjects)...)
+ if len(objs) == 0 {
+ return nil, false, false
+ }
+
+ slog.Debug("collected objects", "count", len(objs))
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
+ slog.Debug("found valid tool call", "name", n)
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
@@ -122,54 +141,82 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool)
}
}
- return toolCalls, len(toolCalls) > 0
-}
-
-// routeToolParsing is a helper function that routes what kind of tool parsing to use
-func routeToolParsing(s string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, bool) {
- if strings.HasPrefix(s, "[{") || strings.HasPrefix(s, "```") || strings.HasPrefix(s, "{") {
- if toolCalls, ok := parseJSONToolCalls(tmpl, s); ok {
- return toolCalls, false, true
- }
- // in the case the JSON never finishes, the acuumulated content should be sent downstream
- return nil, true, true
- }
- // TODO(parthsareen): add python tool call support
- return nil, false, false
+ slog.Debug("parsed tool calls", "count", len(toolCalls))
+ return toolCalls, len(toolCalls) > 0, true
}
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
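-// Returns tool calls, whether parsing is incomplete, and any errors.
+// s is appended to the parser's buffer before parsing. It returns the tool calls and
+// true when the buffered content decoded as complete JSON; p.state records whether
+// more input is still expected.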
-func ParseToolCalls(s string, toolToken string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, error) {
- if tmpl == nil {
- return nil, false, fmt.Errorf("no template provided")
- }
+func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, bool) {
+ p.sb.WriteString(s)
+ s = p.sb.String()
s = strings.TrimSpace(s)
+ slog.Debug("parse tool calls", "content", s)
+
if len(s) == 0 {
- return nil, false, fmt.Errorf("empty input string")
+ return nil, false
}
- if toolToken != "" {
- if strings.HasPrefix(s, toolToken) {
- s = strings.TrimSpace(s[len(toolToken):])
- tc, _, ok := routeToolParsing(s, tmpl)
- if len(tc) == 0 || !ok {
- return nil, true, nil
- }
- return tc, false, nil
+ hasPrefix := false
+ if p.toolPrefix != "" {
+ if strings.HasPrefix(s, p.toolPrefix) {
+ s = strings.TrimSpace(s[len(p.toolPrefix):])
+ slog.Debug("tool prefix", "prefix", p.toolPrefix, "content", s)
+ p.state = PartialTool
+ hasPrefix = true
// Special token end case
- } else if strings.HasSuffix(s, toolToken[2:]) {
- tc := api.ToolCall{
- Function: api.ToolCallFunction{
- Name: toolToken,
- },
- }
- return []api.ToolCall{tc}, true, nil
+ // Guard the slice: a one-byte prefix would panic on [2:], and a two-byte
+ // prefix would leave an empty suffix that matches every string.
+ } else if len(p.toolPrefix) > 2 && strings.HasSuffix(s, p.toolPrefix[2:]) {
+ p.state = PartialTool
+ p.sb.Reset()
+ slog.Debug("setting to no tool", "content", s)
+ return nil, false
}
}
+ tcs, partial, ok := p.parseJSONToolCalls(s)
- tc, partial, ok := routeToolParsing(s, tmpl)
- if !ok {
- return nil, false, fmt.Errorf("failed to parse tool calls for input: %q", s)
+ // TODO: figure out how to return the remaining string if not partial anymore
+ // update state
+ switch {
+ case !ok && !partial && hasPrefix:
+ p.state = PartialTool
+ case !ok && !partial:
+ p.state = NoTool
+ case !ok && partial:
+ p.state = PartialTool
+ case len(tcs) > 0:
+ p.state = ToolCall
+ }
+
+ if p.state == NoTool || p.state == ToolCall {
+ slog.Debug("resetting string builder", "state", p.state)
+ p.sb.Reset()
+ }
+
+ if !ok {
+ return nil, false
+ }
+
+ slog.Debug("returning tool calls", "tool calls", tcs)
+ fmt.Println("end state", p.state)
+ if p.toolPrefix == "" {
+ p.done = true
+ }
+
+ fmt.Println("len tcs", len(tcs))
+ return tcs, true
+}
+
+// NewToolParser builds a ToolParser for model. It returns nil when ToolTemplate
+// finds no tool template, so callers must check the result before use. A missing
+// tool prefix is not an error; the parser then has an empty toolPrefix.
+func NewToolParser(model *Model) *ToolParser {
+ // The ok result is ignored on purpose: ToolPrefix yields "" when it finds nothing.
+ prefix, _ := ToolPrefix(model.Template.Template)
+ slog.Debug("tool prefix", "prefix", prefix)
+ toolTmpl, found := ToolTemplate(model)
+ if !found {
+ return nil
+ }
+
+ // state and done start at their zero values (NoTool, false).
+ return &ToolParser{
+ tmpl: toolTmpl,
+ sb: &strings.Builder{},
+ toolPrefix: prefix,
}
- return tc, partial, nil
}
diff --git a/server/tools_test.go b/server/tools_test.go
index 9250e82ee..e016232f7 100644
--- a/server/tools_test.go
+++ b/server/tools_test.go
@@ -149,9 +149,10 @@ func TestParseToolCalls(t *testing.T) {
wantErr: false,
},
{
+ // TODO: fix the spacing issue
name: "qwen with single tool call",
model: "qwen2.5-coder",
- output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
+ output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `,
token: "",
expected: []api.ToolCall{t1},
wantErr: false,
@@ -185,7 +186,7 @@ func TestParseToolCalls(t *testing.T) {
}
for _, tt := range cases {
- t.Run(tt.model, func(t *testing.T) {
+ t.Run(tt.name, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
@@ -204,25 +205,17 @@ func TestParseToolCalls(t *testing.T) {
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
- tmpl, ok := ToolTemplate(m)
- if !ok {
- t.Fatal("no tool template found")
- }
+ tp := NewToolParser(m)
got := []api.ToolCall{}
- tokens := strings.Fields(tt.output)
- sb := strings.Builder{}
success := false
+ tokens := strings.Fields(tt.output)
for _, tok := range tokens {
- sb.WriteString(" " + tok)
- toolCalls, partial, err := ParseToolCalls(sb.String(), tt.token, tmpl)
- if err == nil {
+ s := " " + tok
+ toolCalls, ok := tp.ParseToolCalls(s)
+ if ok {
success = true
}
- if partial {
- continue
- }
got = append(got, toolCalls...)
- sb.Reset()
}
if !tt.wantErr {
@@ -237,45 +230,3 @@ func TestParseToolCalls(t *testing.T) {
})
}
}
-
-func TestParseObjects(t *testing.T) {
- tests := []struct {
- input string
- want []map[string]any
- }{
- {
- input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
- want: []map[string]any{
- {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
- {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
- },
- },
- {
- input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `,
- want: []map[string]any{
- {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
- },
- },
- {
- input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} `,
- want: []map[string]any{
- {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
- {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
- },
- },
- {
- input: `{"name": "get_current_weather", "arguments": `,
- want: nil,
- },
- }
-
- for _, tc := range tests {
- t.Run(tc.input, func(t *testing.T) {
- got := parseObjects(tc.input)
-
- if diff := cmp.Diff(got, tc.want); diff != "" {
- t.Errorf("mismatch (-got +want):\n%s", diff)
- }
- })
- }
-}