From 6cb74940616f7f9f618950869111802af4c94123 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Wed, 7 May 2025 19:35:11 -0700 Subject: [PATCH] checkpoint for new parser TODO: - cleanup routes interface - internal/external states --- go.mod | 2 +- server/routes.go | 14 +++ server/tools.go | 51 ++++++++- server/tools_test.go | 265 +++++++++++++++++++++++++++---------------- 4 files changed, 231 insertions(+), 101 deletions(-) diff --git a/go.mod b/go.mod index d9de611ba..b92df2ef9 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/dlclark/regexp2 v1.11.4 github.com/emirpasic/gods/v2 v2.0.0-alpha + github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 github.com/google/go-cmp v0.6.0 github.com/mattn/go-runewidth v0.0.14 github.com/nlpodyssey/gopickle v0.3.0 @@ -35,7 +36,6 @@ require ( github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect github.com/kr/text v0.2.0 // indirect diff --git a/server/routes.go b/server/routes.go index 1a77ff77b..3fdd24a61 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1519,14 +1519,28 @@ func (s *Server) ChatHandler(c *gin.Context) { if len(req.Tools) > 0 && tp.state != Done { fmt.Println("checking tool calls") + /* + This should give us a few return things we shouldnt have to build things up. + 1. tool calls if any + 2. leftover tokens if any - this happens in the partial case where we have a prefix inside a string + 3. if we need to skip this loop and not send anything back + + + between these three things, we should just be switching on either the state or something to capture this + potentially consider a difference between internal and external state + */ toolCalls, leftover, ok := tp.ParseToolCalls(r.Content) + + // todo: this should just be one check/state coming back from the parse tool calls if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) { return } + // todo: our second state also should just be a var if tp.state == ContainsPartialPrefix { fmt.Println("sending tokens", leftover) res.Message.Content = leftover } + // TODO: this can be done inside the parse tool calls if ok && len(toolCalls) > 0 { res.Message.ToolCalls = toolCalls for i := range toolCalls { diff --git a/server/tools.go b/server/tools.go index d4401270b..faa2cb9e2 100644 --- a/server/tools.go +++ b/server/tools.go @@ -20,7 +20,6 @@ const ( SendTokens State = iota GreedyToolWithPrefix GreedyToolNoPrefix - // ToolCall ForceTools ToolSuffix ContainsPartialPrefix @@ -56,7 +55,7 @@ type ToolParser struct { } // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. -// Returns parsed tool calls and a boolean indicating if the JSON is incomplete +// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { var b bytes.Buffer if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{ @@ -93,6 +92,9 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { for _, v := range o { all = append(all, collect(v)...) } + default: + // TODO: err or fallback + return nil } return all @@ -241,9 +243,37 @@ func (p *ToolParser) updateInputState(s string, hasPrefix bool) (string, bool) { return s, true } +func (p *ToolParser) sendTokens(original string, hasPrefix bool) (string, bool) { + if p.state == SendTokens { + return "", false + } + if p.state == ContainsPartialPrefix { + idx := strings.Index(original, p.toolPrefix) + if idx != -1 { + // still keeps the prefix + p.state = ContainsPartialPrefix + return original[:idx], false + } else { + fmt.Println("some weird state") + } + } else if strings.HasSuffix(original, p.toolPrefix[2:]) { + // can be with string or just the token + if hasPrefix { + original = strings.TrimSpace(original[:len(original)-(len(p.toolPrefix)+1)]) + return original, false + } else { + p.state = ToolSuffix + p.sb.Reset() + } + } + + return "", true +} + // ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. // Returns tool calls, whether parsing is incomplete, and any errors. func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) { + original := s // append input p.sb.WriteString(s) s = p.sb.String() @@ -258,7 +288,20 @@ func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) { s, ok := p.updateInputState(s, hasPrefix) if !ok { if p.state == ContainsPartialPrefix { - return nil, s, false + idx := strings.Index(original, p.toolPrefix) + if idx != -1 { + // still keeps the prefix + p.state = ContainsPartialPrefix + // p.sb.Reset() + // p.sb.WriteString(original[idx:]) + return nil, original[:idx], false + } else { + fmt.Println("some weird state") + } + // } + // s, ok = p.sendTokens(original, hasPrefix) + // if ok { + // return nil, s, true } return nil, "", false } @@ -292,7 +335,7 @@ func NewToolParser(model *Model) *ToolParser { } else { state = GreedyToolWithPrefix } - fmt.Println("state", state) + fmt.Println("setup state", state) return &ToolParser{ tmpl: tmpl, sb: &strings.Builder{}, diff --git a/server/tools_test.go b/server/tools_test.go index 7cffa5c25..29d057c75 100644 --- a/server/tools_test.go +++ b/server/tools_test.go @@ -48,40 +48,45 @@ func TestParseToolCalls(t *testing.T) { } cases := []struct { - name string - model string - output string - expected []api.ToolCall - wantErr bool + name string + model string + output string + expectedToolCall []api.ToolCall + expectedTokens string + wantErr bool }{ { - name: "mistral invalid json", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, - expected: []api.ToolCall{}, - wantErr: true, + name: "mistral invalid json", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: "", + wantErr: true, }, { - name: "mistral multiple tool calls - no prefix", - model: "mistral", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expected: []api.ToolCall{t1, t2}, - wantErr: false, + name: "mistral multiple tool calls - no prefix", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + wantErr: false, }, { name: "mistral tool calls with text in between - no prefix", model: "mistral", output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expected: []api.ToolCall{t1, t2}, - wantErr: false, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + wantErr: false, }, { - name: "mistral valid json - with prefix", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expected: []api.ToolCall{t1, t2}, - wantErr: false, + name: "mistral valid json - with prefix", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + wantErr: false, }, { // In this case we'd be ignoring the text in between and just returning the tool calls @@ -89,15 +94,17 @@ func TestParseToolCalls(t *testing.T) { model: "mistral", output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expected: []api.ToolCall{t1, t2, t1, t2}, - wantErr: false, + expectedToolCall: []api.ToolCall{t1, t2, t1, t2}, + expectedTokens: "", + wantErr: false, }, { - name: "mistral incomplete json", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, - expected: []api.ToolCall{}, - wantErr: true, + name: "mistral incomplete json", + model: "mistral", + output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: "", + wantErr: true, }, { name: "mistral without tool token", @@ -105,15 +112,17 @@ func TestParseToolCalls(t *testing.T) { output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expected: []api.ToolCall{}, - wantErr: true, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + wantErr: true, }, { - name: "mistral without tool token - tool first", - model: "mistral", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expected: []api.ToolCall{t1, t2}, - wantErr: false, + name: "mistral without tool token - tool first", + model: "mistral", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + wantErr: false, }, { name: "command-r-plus with json block", @@ -136,15 +145,17 @@ func TestParseToolCalls(t *testing.T) { } ] ` + "```", - expected: []api.ToolCall{t1, t2}, - wantErr: false, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + wantErr: false, }, { - name: "firefunction with functools", - model: "firefunction", - output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expected: []api.ToolCall{t1, t2}, - wantErr: false, + name: "firefunction with functools", + model: "firefunction", + output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + wantErr: false, }, { name: "llama3 with tool call tags", @@ -152,64 +163,106 @@ func TestParseToolCalls(t *testing.T) { output: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expected: []api.ToolCall{t1}, - wantErr: false, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + wantErr: false, }, { - name: "xlam with tool_calls wrapper", - model: "xlam", - output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, - expected: []api.ToolCall{t1, t2}, - wantErr: false, + name: "xlam with tool_calls wrapper", + model: "xlam", + output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + wantErr: false, }, { - name: "qwen with single tool call", - model: "qwen2.5-coder", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expected: []api.ToolCall{t1}, - wantErr: false, + name: "qwen2.5 with single tool call", + model: "qwen2.5-coder", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + wantErr: false, }, { - name: "qwen with invalid tool token", - model: "qwen2.5-coder", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expected: []api.ToolCall{t1, t2}, - wantErr: false, + name: "qwen with invalid tool token", + model: "qwen2.5-coder", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", + wantErr: false, }, { - name: "qwen3 with single tool call and thinking", - model: "qwen3", - output: `Okay, let me think what tool we should use...{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expected: []api.ToolCall{t1}, - wantErr: false, + // tests the leftover logic as well + name: "qwen3 with single tool call and thinking", + model: "qwen3", + output: `Okay, let me think what tool we should use...{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "Okay, let me think what tool we should use...", + wantErr: false, }, { - name: "qwen3 with single tool call and thinking spaces", - model: "qwen3", - output: `Okay, let me think what tool we should use... {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expected: []api.ToolCall{t1}, - wantErr: false, + name: "qwen3 with single tool call and thinking spaces", + model: "qwen3", + output: `Okay, let me think what tool we should use... {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "Okay, let me think what tool we should use...", + wantErr: false, + }, + // { + // name: "qwen3 testing", + // model: "qwen3", + // output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + // expectedToolCall: []api.ToolCall{}, + // expectedTokens: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + // wantErr: true, + // }, + // { + // name: "qwen3 testing 2", + // model: "qwen3", + // output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + // expectedToolCall: []api.ToolCall{t1}, + // expectedTokens: ``, + // wantErr: true, + // }, + { + name: "qwen with no tool calls", + model: "qwen2.5-coder", + output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + expectedToolCall: []api.ToolCall{}, + expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + wantErr: true, }, { - name: "qwen with no tool calls", - model: "qwen2.5-coder", - output: " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", - expected: []api.ToolCall{}, - wantErr: true, + name: "llama3.2 with tool call - no prefix", + model: "llama3.2", + output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + wantErr: false, }, { - name: "llama3.2 with tool call - no prefix", - model: "llama3.2", - output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expected: []api.ToolCall{t1}, - wantErr: false, + name: "llama3.2 with incomplete tool call - no prefix", + model: "llama3.2", + output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `, + expectedToolCall: []api.ToolCall{}, + expectedTokens: "", + wantErr: true, }, { - name: "llama3.2 with tool call - in middle", - model: "llama3.2", - output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expected: []api.ToolCall{}, - wantErr: true, + name: "llama3.2 with tool call - in middle", + model: "llama3.2", + output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + wantErr: true, + }, + { + name: "llama3.2 - fake tool prefix", + model: "llama3.2", + output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, + wantErr: true, }, } @@ -232,12 +285,10 @@ func TestParseToolCalls(t *testing.T) { t.Run("template", func(t *testing.T) { actual := &bytes.Buffer{} // Create new buffer for each test - t.Log("template", tmpl, "model", tt.model) if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil { t.Fatal(err) } - t.Log("actual", actual.String()) if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } @@ -248,31 +299,53 @@ func TestParseToolCalls(t *testing.T) { tp := NewToolParser(m) got := []api.ToolCall{} success := false + var actualTokens strings.Builder + tokens := strings.Fields(tt.output) for _, tok := range tokens { + add := true s := " " + tok - var toolCalls []api.ToolCall - var ok bool + + // TODO(parthsareen): This logic is brittle as it mocks the logic in route, however can if tp.state != Done { - toolCalls, _, ok = tp.ParseToolCalls(s) - if ok { - success = true + toolCalls, leftover, ok := tp.ParseToolCalls(s) + if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) { + continue } - got = append(got, toolCalls...) + if tp.state == ContainsPartialPrefix { + // actualTokens.Reset() + actualTokens.WriteString(leftover) + t.Log("leftover", leftover) + add = false + // continue + } + if ok && len(toolCalls) > 0 { + success = true + got = append(got, toolCalls...) + add = false + // actualTokens.Reset() + } + } + // s = strings.TrimSpace(s) + if add { + actualTokens.WriteString(s) } } if !tt.wantErr { - if diff := cmp.Diff(got, tt.expected); diff != "" { + if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } } if !success && !tt.wantErr { t.Errorf("expected success but got errors") } + stripped := strings.TrimSpace(actualTokens.String()) + if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" { + t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens) + t.Errorf("tokens mismatch (-got +want):\n%s", diff) + } }) }) } } - -// TODO: add tests to check string sent not just tool