checkpoint for new parser

TODO:
- cleanup routes interface
- internal/external states
This commit is contained in:
ParthSareen 2025-05-07 19:35:11 -07:00
parent a44734b030
commit 6cb7494061
4 changed files with 231 additions and 101 deletions

2
go.mod
View File

@ -19,6 +19,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
@ -35,7 +36,6 @@ require (
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect

View File

@ -1519,14 +1519,28 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 && tp.state != Done {
fmt.Println("checking tool calls")
/*
This should give us a few return things we shouldnt have to build things up.
1. tool calls if any
2. leftover tokens if any - this happens in the partial case where we have a prefix inside a string
3. if we need to skip this loop and not send anything back
between these three things, we should just be switching on either the state or something to capture this
potentially consider a difference between internal and external state
*/
toolCalls, leftover, ok := tp.ParseToolCalls(r.Content)
// todo: this should just be one check/state coming back from the parse tool calls
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
return
}
// todo: our second state also should just be a var
if tp.state == ContainsPartialPrefix {
fmt.Println("sending tokens", leftover)
res.Message.Content = leftover
}
// TODO: this can be done inside the parse tool calls
if ok && len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {

View File

@ -20,7 +20,6 @@ const (
SendTokens State = iota
GreedyToolWithPrefix
GreedyToolNoPrefix
// ToolCall
ForceTools
ToolSuffix
ContainsPartialPrefix
@ -56,7 +55,7 @@ type ToolParser struct {
}
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// Returns parsed tool calls and a boolean indicating if the JSON is incomplete
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
var b bytes.Buffer
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
@ -93,6 +92,9 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
for _, v := range o {
all = append(all, collect(v)...)
}
default:
// TODO: err or fallback
return nil
}
return all
@ -241,9 +243,37 @@ func (p *ToolParser) updateInputState(s string, hasPrefix bool) (string, bool) {
return s, true
}
func (p *ToolParser) sendTokens(original string, hasPrefix bool) (string, bool) {
if p.state == SendTokens {
return "", false
}
if p.state == ContainsPartialPrefix {
idx := strings.Index(original, p.toolPrefix)
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
return original[:idx], false
} else {
fmt.Println("some weird state")
}
} else if strings.HasSuffix(original, p.toolPrefix[2:]) {
// can be with string or just the token
if hasPrefix {
original = strings.TrimSpace(original[:len(original)-(len(p.toolPrefix)+1)])
return original, false
} else {
p.state = ToolSuffix
p.sb.Reset()
}
}
return "", true
}
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
// Returns tool calls, whether parsing is incomplete, and any errors.
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) {
original := s
// append input
p.sb.WriteString(s)
s = p.sb.String()
@ -258,7 +288,20 @@ func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) {
s, ok := p.updateInputState(s, hasPrefix)
if !ok {
if p.state == ContainsPartialPrefix {
return nil, s, false
idx := strings.Index(original, p.toolPrefix)
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
// p.sb.Reset()
// p.sb.WriteString(original[idx:])
return nil, original[:idx], false
} else {
fmt.Println("some weird state")
}
// }
// s, ok = p.sendTokens(original, hasPrefix)
// if ok {
// return nil, s, true
}
return nil, "", false
}
@ -292,7 +335,7 @@ func NewToolParser(model *Model) *ToolParser {
} else {
state = GreedyToolWithPrefix
}
fmt.Println("state", state)
fmt.Println("setup state", state)
return &ToolParser{
tmpl: tmpl,
sb: &strings.Builder{},

View File

@ -48,40 +48,45 @@ func TestParseToolCalls(t *testing.T) {
}
cases := []struct {
name string
model string
output string
expected []api.ToolCall
wantErr bool
name string
model string
output string
expectedToolCall []api.ToolCall
expectedTokens string
wantErr bool
}{
{
name: "mistral invalid json",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
expected: []api.ToolCall{},
wantErr: true,
name: "mistral invalid json",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
wantErr: true,
},
{
name: "mistral multiple tool calls - no prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2},
wantErr: false,
name: "mistral multiple tool calls - no prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "mistral tool calls with text in between - no prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2},
wantErr: false,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: false,
},
{
name: "mistral valid json - with prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2},
wantErr: false,
name: "mistral valid json - with prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
// In this case we'd be ignoring the text in between and just returning the tool calls
@ -89,15 +94,17 @@ func TestParseToolCalls(t *testing.T) {
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2, t1, t2},
wantErr: false,
expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "mistral incomplete json",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
expected: []api.ToolCall{},
wantErr: true,
name: "mistral incomplete json",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
wantErr: true,
},
{
name: "mistral without tool token",
@ -105,15 +112,17 @@ func TestParseToolCalls(t *testing.T) {
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{},
wantErr: true,
expectedToolCall: []api.ToolCall{},
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: true,
},
{
name: "mistral without tool token - tool first",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2},
wantErr: false,
name: "mistral without tool token - tool first",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "command-r-plus with json block",
@ -136,15 +145,17 @@ func TestParseToolCalls(t *testing.T) {
}
]
` + "```",
expected: []api.ToolCall{t1, t2},
wantErr: false,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "firefunction with functools",
model: "firefunction",
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2},
wantErr: false,
name: "firefunction with functools",
model: "firefunction",
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "llama3 with tool call tags",
@ -152,64 +163,106 @@ func TestParseToolCalls(t *testing.T) {
output: `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
</tool_call>`,
expected: []api.ToolCall{t1},
wantErr: false,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
wantErr: false,
},
{
name: "xlam with tool_calls wrapper",
model: "xlam",
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
expected: []api.ToolCall{t1, t2},
wantErr: false,
name: "xlam with tool_calls wrapper",
model: "xlam",
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "qwen with single tool call",
model: "qwen2.5-coder",
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expected: []api.ToolCall{t1},
wantErr: false,
name: "qwen2.5 with single tool call",
model: "qwen2.5-coder",
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
wantErr: false,
},
{
name: "qwen with invalid tool token",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2},
wantErr: false,
name: "qwen with invalid tool token",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "qwen3 with single tool call and thinking",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expected: []api.ToolCall{t1},
wantErr: false,
// tests the leftover logic as well
name: "qwen3 with single tool call and thinking",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
},
{
name: "qwen3 with single tool call and thinking spaces",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expected: []api.ToolCall{t1},
wantErr: false,
name: "qwen3 with single tool call and thinking spaces",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
},
// {
// name: "qwen3 testing",
// model: "qwen3",
// output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{},
// expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// wantErr: true,
// },
// {
// name: "qwen3 testing 2",
// model: "qwen3",
// output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{t1},
// expectedTokens: `<think></think>`,
// wantErr: true,
// },
{
name: "qwen with no tool calls",
model: "qwen2.5-coder",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
wantErr: true,
},
{
name: "qwen with no tool calls",
model: "qwen2.5-coder",
output: " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
expected: []api.ToolCall{},
wantErr: true,
name: "llama3.2 with tool call - no prefix",
model: "llama3.2",
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
wantErr: false,
},
{
name: "llama3.2 with tool call - no prefix",
model: "llama3.2",
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expected: []api.ToolCall{t1},
wantErr: false,
name: "llama3.2 with incomplete tool call - no prefix",
model: "llama3.2",
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
wantErr: true,
},
{
name: "llama3.2 with tool call - in middle",
model: "llama3.2",
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expected: []api.ToolCall{},
wantErr: true,
name: "llama3.2 with tool call - in middle",
model: "llama3.2",
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
},
{
name: "llama3.2 - fake tool prefix",
model: "llama3.2",
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
},
}
@ -232,12 +285,10 @@ func TestParseToolCalls(t *testing.T) {
t.Run("template", func(t *testing.T) {
actual := &bytes.Buffer{} // Create new buffer for each test
t.Log("template", tmpl, "model", tt.model)
if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
t.Log("actual", actual.String())
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
@ -248,31 +299,53 @@ func TestParseToolCalls(t *testing.T) {
tp := NewToolParser(m)
got := []api.ToolCall{}
success := false
var actualTokens strings.Builder
tokens := strings.Fields(tt.output)
for _, tok := range tokens {
add := true
s := " " + tok
var toolCalls []api.ToolCall
var ok bool
// TODO(parthsareen): This logic is brittle as it mocks the logic in route, however can
if tp.state != Done {
toolCalls, _, ok = tp.ParseToolCalls(s)
if ok {
success = true
toolCalls, leftover, ok := tp.ParseToolCalls(s)
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
continue
}
got = append(got, toolCalls...)
if tp.state == ContainsPartialPrefix {
// actualTokens.Reset()
actualTokens.WriteString(leftover)
t.Log("leftover", leftover)
add = false
// continue
}
if ok && len(toolCalls) > 0 {
success = true
got = append(got, toolCalls...)
add = false
// actualTokens.Reset()
}
}
// s = strings.TrimSpace(s)
if add {
actualTokens.WriteString(s)
}
}
if !tt.wantErr {
if diff := cmp.Diff(got, tt.expected); diff != "" {
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
if !success && !tt.wantErr {
t.Errorf("expected success but got errors")
}
stripped := strings.TrimSpace(actualTokens.String())
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
}
})
})
}
}
// TODO: add tests to check string sent not just tool