diff --git a/server/routes.go b/server/routes.go index 950185cc6..1a77ff77b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1517,13 +1517,16 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - if len(req.Tools) > 0 && !tp.done { + if len(req.Tools) > 0 && tp.state != Done { fmt.Println("checking tool calls") - toolCalls, ok := tp.ParseToolCalls(r.Content) - if tp.state == PartialTool { - fmt.Println("partial tool, returning") + toolCalls, leftover, ok := tp.ParseToolCalls(r.Content) + if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) { return } + if tp.state == ContainsPartialPrefix { + fmt.Println("sending tokens", leftover) + res.Message.Content = leftover + } if ok && len(toolCalls) > 0 { res.Message.ToolCalls = toolCalls for i := range toolCalls { @@ -1535,6 +1538,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } } + fmt.Println("sending response", res.Message.Content) ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} diff --git a/server/tools.go b/server/tools.go index 25d859948..248a6b78a 100644 --- a/server/tools.go +++ b/server/tools.go @@ -17,17 +17,42 @@ import ( type State int const ( - NoTool State = iota - PartialTool - ToolCall + SendTokens State = iota + GreedyToolWithPrefix + GreedyToolNoPrefix + // ToolCall + ForceTools + ToolSuffix + ContainsPartialPrefix + Done ) +func (s State) String() string { + switch s { + case SendTokens: + return "SendTokens" + case GreedyToolWithPrefix: + return "GreedyToolWithPrefix" + case GreedyToolNoPrefix: + return "GreedyToolNoPrefix" + case ForceTools: + return "ForceTools" + case ToolSuffix: + return "ToolSuffix" + case Done: + return "Done" + case ContainsPartialPrefix: + return "PartialPrefix" + default: + return fmt.Sprintf("Unknown State (%d)", s) + } +} + type ToolParser struct { tmpl *gotmpl.Template state State sb *strings.Builder toolPrefix string - done bool } // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. @@ -49,8 +74,6 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { return nil, false, false } - // slog.Debug("template", "template", b.String()) - // ! this can be either a map or an array var temp any err := jsonv2.Unmarshal(b.Bytes(), &temp) @@ -88,7 +111,6 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { if len(templateObjects) == 0 { return nil, false, false } - // fmt.Println("template objects", templateObjects) // find the keys that correspond to the name and arguments fields var name, arguments string @@ -142,81 +164,134 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { } slog.Debug("parsed tool calls", "count", len(toolCalls)) - return toolCalls, len(toolCalls) > 0, true + return toolCalls, false, true +} + +func (p *ToolParser) updateState(ok bool, partial bool, tcs []api.ToolCall) { + switch { + case !ok && !partial && p.state == ForceTools: + fmt.Println("Case: !ok && !partial && ForceTools - staying in force tools, resetting buffer") + // force partial tool if we have a prefix + // no op and stay in force tools + p.sb.Reset() + case !ok && !partial: + fmt.Println("Case: !ok && !partial") + fmt.Println("state", p.state) + if p.state == GreedyToolNoPrefix { + fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done") + p.state = Done + } + if p.state == GreedyToolWithPrefix { + fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens") + p.state = SendTokens + } + p.sb.Reset() + case !ok && partial: + fmt.Println("Case: !ok && partial - accumulating partial content") + + // ! acucumulate + + case len(tcs) > 0: + fmt.Println("Case: tool calls found") + // do not parse again in the greedy JSON case as soon as we have a tool call + if p.state == GreedyToolWithPrefix { + p.state = SendTokens + } else if p.state == GreedyToolNoPrefix { + fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens") + p.state = Done + } + p.sb.Reset() + } } // ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. // Returns tool calls, whether parsing is incomplete, and any errors. -func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, bool) { +func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) { p.sb.WriteString(s) s = p.sb.String() s = strings.TrimSpace(s) slog.Debug("parse tool calls", "content", s) if len(s) == 0 { - return nil, false + return nil, "", false } - hasPrefix := false + s, hasPrefix := strings.CutPrefix(s, p.toolPrefix) + fmt.Println("hasPrefix", hasPrefix) + var tcs []api.ToolCall + var partial bool + var ok bool + if p.toolPrefix != "" { - if strings.HasPrefix(s, p.toolPrefix) { - s = strings.TrimSpace(s[len(p.toolPrefix):]) - slog.Debug("tool prefix", "prefix", p.toolPrefix, "content", s) - p.state = PartialTool - hasPrefix = true - // Special token end case - } else if strings.HasSuffix(s, p.toolPrefix[2:]) { - p.state = PartialTool - p.sb.Reset() + if hasPrefix { + p.state = ForceTools + slog.Debug("tool prefix in prefix", "prefix", p.toolPrefix, "content", s) + // partial tool possibly + } else if strings.HasPrefix(p.toolPrefix, s) { + slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s) + // TODO: could possibly err maybe this should be greedy instead? + p.state = ForceTools + return nil, "", false + } else if strings.Contains(s, p.toolPrefix) { + idx := strings.Index(s, p.toolPrefix) + if idx != -1 { + // still keeps the prefix + p.state = ContainsPartialPrefix + p.sb.Reset() + p.sb.WriteString(s[idx:]) + return nil, s[:idx], false + } + } + // Special token end case + // if s, ok := strings.CutSuffix(s, p.toolPrefix[2:]); ok { + if strings.HasSuffix(s, p.toolPrefix[2:]) { + // can be with string or just the token + if hasPrefix { + s = strings.TrimSpace(s[:len(s)-(len(p.toolPrefix)+1)]) + } else { + p.state = ToolSuffix + p.sb.Reset() + return nil, "", false + } slog.Debug("setting to no tool", "content", s) - return nil, false } } - tcs, partial, ok := p.parseJSONToolCalls(s) - - // TODO: figure out how to return the remaining string if not partial anymore - // update state - switch { - case !ok && !partial && hasPrefix: - p.state = PartialTool - case !ok && !partial: - p.state = NoTool - case !ok && partial: - p.state = PartialTool - case len(tcs) > 0: - p.state = ToolCall + fmt.Println("s before parsing", s) + if p.state == SendTokens { + fmt.Println("returning nil cause of send tokens") + return nil, "", false } - - if p.state == NoTool || p.state == ToolCall { - slog.Debug("resetting string builder", "state", p.state) - p.sb.Reset() - } - - if !ok { - return nil, false - } - + tcs, partial, ok = p.parseJSONToolCalls(s) slog.Debug("returning tool calls", "tool calls", tcs) fmt.Println("end state", p.state) - if p.toolPrefix == "" { - p.done = true - } fmt.Println("len tcs", len(tcs)) - return tcs, true + p.updateState(ok, partial, tcs) + if !ok { + return nil, "", false + } + + return tcs, "", true } func NewToolParser(model *Model) *ToolParser { templateToolPrefix, _ := ToolPrefix(model.Template.Template) + templateToolPrefix = strings.TrimSpace(templateToolPrefix) slog.Debug("tool prefix", "prefix", templateToolPrefix) tmpl, ok := ToolTemplate(model) if !ok { return nil } + var state State + if templateToolPrefix == "" { + state = GreedyToolNoPrefix + } else { + state = GreedyToolWithPrefix + } return &ToolParser{ tmpl: tmpl, sb: &strings.Builder{}, toolPrefix: templateToolPrefix, - done: false, + state: state, } } diff --git a/server/tools_test.go b/server/tools_test.go index e016232f7..ac55c1558 100644 --- a/server/tools_test.go +++ b/server/tools_test.go @@ -51,7 +51,7 @@ func TestParseToolCalls(t *testing.T) { name string model string output string - token string + prefix string expected []api.ToolCall wantErr bool }{ @@ -59,23 +59,50 @@ func TestParseToolCalls(t *testing.T) { name: "mistral invalid json", model: "mistral", output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, - token: "[TOOL_CALLS]", + prefix: "[TOOL_CALLS]", expected: []api.ToolCall{}, wantErr: true, }, { - name: "mistral valid json", + name: "mistral multiple tool calls - no prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + prefix: "[TOOL_CALLS]", + expected: []api.ToolCall{t1, t2}, + wantErr: false, + }, + { + name: "mistral tool calls with text in between - no prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + prefix: "[TOOL_CALLS]", + expected: []api.ToolCall{t1, t2}, + wantErr: false, + }, + { + name: "mistral valid json - with prefix", model: "mistral", output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - token: "[TOOL_CALLS]", + prefix: "[TOOL_CALLS]", expected: []api.ToolCall{t1, t2}, wantErr: false, }, + { + // In this case we'd be ignoring the text in between and just returning the tool calls + name: "mistral valid json with text in between - with prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + prefix: "[TOOL_CALLS]", + expected: []api.ToolCall{t1, t2, t1, t2}, + wantErr: false, + }, { name: "mistral incomplete json", model: "mistral", output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, - token: "[TOOL_CALLS]", + prefix: "[TOOL_CALLS]", expected: []api.ToolCall{}, wantErr: true, }, @@ -85,7 +112,7 @@ func TestParseToolCalls(t *testing.T) { output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - token: "[TOOL_CALLS]", + prefix: "[TOOL_CALLS]", expected: []api.ToolCall{}, wantErr: true, }, @@ -93,7 +120,7 @@ func TestParseToolCalls(t *testing.T) { name: "mistral without tool token - tool first", model: "mistral", output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - token: "[TOOL_CALLS]", + prefix: "[TOOL_CALLS]", expected: []api.ToolCall{t1, t2}, wantErr: false, }, @@ -118,7 +145,7 @@ func TestParseToolCalls(t *testing.T) { } ] ` + "```", - token: "Action:", + prefix: "Action: ```json", expected: []api.ToolCall{t1, t2}, wantErr: false, }, @@ -126,7 +153,7 @@ func TestParseToolCalls(t *testing.T) { name: "firefunction with functools", model: "firefunction", output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - token: "functools", + prefix: "functools", expected: []api.ToolCall{t1, t2}, wantErr: false, }, @@ -136,7 +163,7 @@ func TestParseToolCalls(t *testing.T) { output: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - token: "", + prefix: "", expected: []api.ToolCall{t1}, wantErr: false, }, @@ -144,16 +171,15 @@ func TestParseToolCalls(t *testing.T) { name: "xlam with tool_calls wrapper", model: "xlam", output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, - token: "", + prefix: "", expected: []api.ToolCall{t1, t2}, wantErr: false, }, { - // TODO: fix the spacing issue name: "qwen with single tool call", model: "qwen2.5-coder", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - token: "", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + prefix: "", expected: []api.ToolCall{t1}, wantErr: false, }, @@ -161,15 +187,31 @@ func TestParseToolCalls(t *testing.T) { name: "qwen with invalid tool token", model: "qwen2.5-coder", output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - token: "[TOOL_CALLS]", + prefix: "[TOOL_CALLS]", expected: []api.ToolCall{t1, t2}, wantErr: false, }, + { + name: "qwen3 with single tool call and thinking", + model: "qwen3", + output: `Okay, let me think what tool we should use...{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + prefix: "", + expected: []api.ToolCall{t1}, + wantErr: false, + }, + { + name: "qwen3 with single tool call and thinking spaces", + model: "qwen3", + output: `Okay, let me think what tool we should use... {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + prefix: "", + expected: []api.ToolCall{t1}, + wantErr: false, + }, { name: "qwen with no tool calls", model: "qwen2.5-coder", output: " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", - token: "", + prefix: "", expected: []api.ToolCall{}, wantErr: true, }, @@ -211,11 +253,15 @@ func TestParseToolCalls(t *testing.T) { tokens := strings.Fields(tt.output) for _, tok := range tokens { s := " " + tok - toolCalls, ok := tp.ParseToolCalls(s) - if ok { - success = true + var toolCalls []api.ToolCall + var ok bool + if tp.state != Done { + toolCalls, _, ok = tp.ParseToolCalls(s) + if ok { + success = true + } + got = append(got, toolCalls...) } - got = append(got, toolCalls...) } if !tt.wantErr { @@ -230,3 +276,5 @@ func TestParseToolCalls(t *testing.T) { }) } } + +// TODO: add tests to check string sent not just tool