diff --git a/go.mod b/go.mod index 283286b7d..d9de611ba 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect github.com/kr/text v0.2.0 // indirect diff --git a/go.sum b/go.sum index 5755616f6..780a76f10 100644 --- a/go.sum +++ b/go.sum @@ -69,6 +69,8 @@ github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3 github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= +github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 h1:+VexzzkMLb1tnvpuQdGT/DicIRW7MN8ozsXqBMgp0Hk= +github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= diff --git a/server/model.go b/server/model.go index b5c91ef1e..eb28d3733 100644 --- a/server/model.go +++ b/server/model.go @@ -210,7 +210,16 @@ func nodeContainsToolCalls(n *parse.IfNode) bool { return false } -func ToolToken(tmpl *gotmpl.Template) (string, bool) { +func ToolPrefix2(tmpl *gotmpl.Template) (string, bool) { + tokenText, ok := extractToolCallsTemplate(tmpl) + if !ok { + return "", false + } + tokenText = strings.TrimSpace(tokenText) + return tokenText, true +} + +func ToolPrefix(tmpl *gotmpl.Template) (string, bool) { tokenText, ok := extractToolCallsTemplate(tmpl) if !ok { return "", false diff --git a/server/model_test.go b/server/model_test.go index 498fdb408..8fd19d2db 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -80,7 +80,7 @@ func TestToolToken(t *testing.T) { if err != nil { t.Fatalf("failed to parse template: %v", err) } - got, ok := ToolToken(tmpl) + got, ok := ToolPrefix(tmpl) if got != tt.want { t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) } diff --git a/server/routes.go b/server/routes.go index c40ab211c..950185cc6 100644 --- a/server/routes.go +++ b/server/routes.go @@ -21,7 +21,6 @@ import ( "slices" "strings" "syscall" - gotmpl "text/template" "time" "github.com/gin-contrib/cors" @@ -1486,26 +1485,13 @@ func (s *Server) ChatHandler(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) - var sb strings.Builder + // var sb strings.Builder var toolCallIndex int = 0 - var templateToolToken string - var tmpl *gotmpl.Template + var tp *ToolParser if len(req.Tools) > 0 { - var ok bool - templateToolToken, ok = ToolToken(m.Template.Template) - if !ok { - slog.Debug("no tool token found") - } - tmpl, ok = ToolTemplate(m) - if !ok { - slog.Debug("no tool template found") - } + tp = NewToolParser(m) } - checkToolCall := false - if len(req.Tools) > 0 { - checkToolCall = true - } if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1526,50 +1512,29 @@ func (s *Server) ChatHandler(c *gin.Context) { } if r.Done { - if sb.Len() > 0 { - res.Message.Content = sb.String() - } res.DoneReason = r.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - sb.WriteString(r.Content) - if len(req.Tools) > 0 && checkToolCall { - slog.Debug("parse tool calls", "content", sb.String(), "templateToolToken", templateToolToken) - toolCalls, partial, err := ParseToolCalls(sb.String(), templateToolToken, tmpl) - if err == nil { - if partial { - // circuit break to remove tool end token - if len(toolCalls) > 0 { - sb.Reset() - } - // If the tool call is partial, we need to wait for the next chunk - return - } + if len(req.Tools) > 0 && !tp.done { + fmt.Println("checking tool calls") + toolCalls, ok := tp.ParseToolCalls(r.Content) + if tp.state == PartialTool { + fmt.Println("partial tool, returning") + return + } + if ok && len(toolCalls) > 0 { res.Message.ToolCalls = toolCalls for i := range toolCalls { toolCalls[i].Function.Index = toolCallIndex toolCallIndex++ } + // Remove content when tool call is present res.Message.Content = "" - ch <- res - // Only way to have multiple calls is to have [] which is derived or provided - // This case occurs when the tool call is a json block - do not allow tool calls again - if templateToolToken == "" || (templateToolToken != "" && !strings.HasPrefix(sb.String(), templateToolToken)) { - checkToolCall = false - } - sb.Reset() - return } } - // If there is no template tool token, we don't need to check for tool calls after the first chunk - if templateToolToken == "" { - checkToolCall = false - } - res.Message.Content = sb.String() - sb.Reset() ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} diff --git a/server/tools.go b/server/tools.go index aa2f729db..25d859948 100644 --- a/server/tools.go +++ b/server/tools.go @@ -2,50 +2,39 @@ package server import ( "bytes" - "encoding/json" "errors" "fmt" "io" + "log/slog" "strings" gotmpl "text/template" + jsonv2 "github.com/go-json-experiment/json" + "github.com/ollama/ollama/api" ) -func parseObjects(s string) []map[string]any { - var objs []map[string]any - for offset := 0; offset < len(s); { - var obj map[string]any - decoder := json.NewDecoder(strings.NewReader(s[offset:])) - err := decoder.Decode(&obj) - switch { - case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF): - return objs - case err != nil: - var syntax *json.SyntaxError - var unmarshalType *json.UnmarshalTypeError - switch { - case errors.As(err, &syntax): - offset += int(syntax.Offset) - continue - case errors.As(err, &unmarshalType): - offset += int(unmarshalType.Offset) - continue - default: - return nil - } - } - offset += int(decoder.InputOffset()) - objs = append(objs, obj) - } - return objs +type State int + +const ( + NoTool State = iota + PartialTool + ToolCall +) + +type ToolParser struct { + tmpl *gotmpl.Template + state State + sb *strings.Builder + toolPrefix string + done bool } // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // Returns parsed tool calls and a boolean indicating if the JSON is incomplete -func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool) { +func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ + if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{ "ToolCalls": { { Function: api.ToolCallFunction{ @@ -57,35 +46,18 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool) }, }, }); err != nil { - return nil, false + return nil, false, false } - templateObjects := parseObjects(b.String()) - if len(templateObjects) == 0 { - return nil, false + // slog.Debug("template", "template", b.String()) + + // ! this can be either a map or an array + var temp any + err := jsonv2.Unmarshal(b.Bytes(), &temp) + if err != nil { + return nil, false, false } - // find the keys that correspond to the name and arguments fields - var name, arguments string - for k, v := range templateObjects[0] { - switch v.(type) { - case string: - name = k - case map[string]any: - arguments = k - } - } - - if name == "" || arguments == "" { - return nil, false - } - - responseObjects := parseObjects(s) - if len(responseObjects) == 0 { - return nil, false - } - - // collect all nested objects var collect func(any) []map[string]any collect = func(obj any) (all []map[string]any) { switch o := obj.(type) { @@ -103,16 +75,63 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool) return all } - var objs []map[string]any - for _, p := range responseObjects { - objs = append(objs, collect(p)...) + var templateObjects []map[string]any + switch t := temp.(type) { + case map[string]any: + templateObjects = []map[string]any{t} + case []map[string]any: + templateObjects = t + // ! fallback? + case []any: + templateObjects = collect(t) } + if len(templateObjects) == 0 { + return nil, false, false + } + // fmt.Println("template objects", templateObjects) + + // find the keys that correspond to the name and arguments fields + var name, arguments string + for k, v := range templateObjects[0] { + switch v.(type) { + case string: + name = k + case map[string]any: + arguments = k + } + } + + if name == "" || arguments == "" { + return nil, false, false + } + var responseObjects any + err = jsonv2.Unmarshal([]byte(s), &responseObjects) + if err != nil { + if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" { + fmt.Println("Detected partial or incomplete JSON.") + fmt.Println("state", p.state) + return nil, true, false + } else { + fmt.Printf("Other error: %v\n", err) + fmt.Println("exiting", p.state) + return nil, false, false + } + } + + var objs []map[string]any + objs = append(objs, collect(responseObjects)...) + if len(objs) == 0 { + return nil, false, false + } + + slog.Debug("collected objects", "count", len(objs)) var toolCalls []api.ToolCall for _, kv := range objs { n, nok := kv[name].(string) a, aok := kv[arguments].(map[string]any) if nok && aok { + slog.Debug("found valid tool call", "name", n) toolCalls = append(toolCalls, api.ToolCall{ Function: api.ToolCallFunction{ Name: n, @@ -122,54 +141,82 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool) } } - return toolCalls, len(toolCalls) > 0 -} - -// routeToolParsing is a helper function that routes what kind of tool parsing to use -func routeToolParsing(s string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, bool) { - if strings.HasPrefix(s, "[{") || strings.HasPrefix(s, "```") || strings.HasPrefix(s, "{") { - if toolCalls, ok := parseJSONToolCalls(tmpl, s); ok { - return toolCalls, false, true - } - // in the case the JSON never finishes, the acuumulated content should be sent downstream - return nil, true, true - } - // TODO(parthsareen): add python tool call support - return nil, false, false + slog.Debug("parsed tool calls", "count", len(toolCalls)) + return toolCalls, len(toolCalls) > 0, true } // ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. // Returns tool calls, whether parsing is incomplete, and any errors. -func ParseToolCalls(s string, toolToken string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, error) { - if tmpl == nil { - return nil, false, fmt.Errorf("no template provided") - } +func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, bool) { + p.sb.WriteString(s) + s = p.sb.String() s = strings.TrimSpace(s) + slog.Debug("parse tool calls", "content", s) + if len(s) == 0 { - return nil, false, fmt.Errorf("empty input string") + return nil, false } - if toolToken != "" { - if strings.HasPrefix(s, toolToken) { - s = strings.TrimSpace(s[len(toolToken):]) - tc, _, ok := routeToolParsing(s, tmpl) - if len(tc) == 0 || !ok { - return nil, true, nil - } - return tc, false, nil + hasPrefix := false + if p.toolPrefix != "" { + if strings.HasPrefix(s, p.toolPrefix) { + s = strings.TrimSpace(s[len(p.toolPrefix):]) + slog.Debug("tool prefix", "prefix", p.toolPrefix, "content", s) + p.state = PartialTool + hasPrefix = true // Special token end case - } else if strings.HasSuffix(s, toolToken[2:]) { - tc := api.ToolCall{ - Function: api.ToolCallFunction{ - Name: toolToken, - }, - } - return []api.ToolCall{tc}, true, nil + } else if strings.HasSuffix(s, p.toolPrefix[2:]) { + p.state = PartialTool + p.sb.Reset() + slog.Debug("setting to no tool", "content", s) + return nil, false } } + tcs, partial, ok := p.parseJSONToolCalls(s) - tc, partial, ok := routeToolParsing(s, tmpl) - if !ok { - return nil, false, fmt.Errorf("failed to parse tool calls for input: %q", s) + // TODO: figure out how to return the remaining string if not partial anymore + // update state + switch { + case !ok && !partial && hasPrefix: + p.state = PartialTool + case !ok && !partial: + p.state = NoTool + case !ok && partial: + p.state = PartialTool + case len(tcs) > 0: + p.state = ToolCall + } + + if p.state == NoTool || p.state == ToolCall { + slog.Debug("resetting string builder", "state", p.state) + p.sb.Reset() + } + + if !ok { + return nil, false + } + + slog.Debug("returning tool calls", "tool calls", tcs) + fmt.Println("end state", p.state) + if p.toolPrefix == "" { + p.done = true + } + + fmt.Println("len tcs", len(tcs)) + return tcs, true +} + +func NewToolParser(model *Model) *ToolParser { + templateToolPrefix, _ := ToolPrefix(model.Template.Template) + slog.Debug("tool prefix", "prefix", templateToolPrefix) + tmpl, ok := ToolTemplate(model) + if !ok { + return nil + } + + return &ToolParser{ + tmpl: tmpl, + sb: &strings.Builder{}, + toolPrefix: templateToolPrefix, + done: false, } - return tc, partial, nil } diff --git a/server/tools_test.go b/server/tools_test.go index 9250e82ee..e016232f7 100644 --- a/server/tools_test.go +++ b/server/tools_test.go @@ -149,9 +149,10 @@ func TestParseToolCalls(t *testing.T) { wantErr: false, }, { + // TODO: fix the spacing issue name: "qwen with single tool call", model: "qwen2.5-coder", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, token: "", expected: []api.ToolCall{t1}, wantErr: false, @@ -185,7 +186,7 @@ func TestParseToolCalls(t *testing.T) { } for _, tt := range cases { - t.Run(tt.model, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) if err != nil { t.Fatal(err) @@ -204,25 +205,17 @@ func TestParseToolCalls(t *testing.T) { t.Run("parse", func(t *testing.T) { m := &Model{Template: tmpl} - tmpl, ok := ToolTemplate(m) - if !ok { - t.Fatal("no tool template found") - } + tp := NewToolParser(m) got := []api.ToolCall{} - tokens := strings.Fields(tt.output) - sb := strings.Builder{} success := false + tokens := strings.Fields(tt.output) for _, tok := range tokens { - sb.WriteString(" " + tok) - toolCalls, partial, err := ParseToolCalls(sb.String(), tt.token, tmpl) - if err == nil { + s := " " + tok + toolCalls, ok := tp.ParseToolCalls(s) + if ok { success = true } - if partial { - continue - } got = append(got, toolCalls...) - sb.Reset() } if !tt.wantErr { @@ -237,45 +230,3 @@ func TestParseToolCalls(t *testing.T) { }) } } - -func TestParseObjects(t *testing.T) { - tests := []struct { - input string - want []map[string]any - }{ - { - input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} `, - want: []map[string]any{ - {"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}}, - {"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}}, - }, - }, - { - input: `{"name": "get_current_weather", "arguments": `, - want: nil, - }, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - got := parseObjects(tc.input) - - if diff := cmp.Diff(got, tc.want); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - } -}