checkpoint for new parser

TODO:
- cleanup routes interface
- internal/external states
This commit is contained in:
ParthSareen 2025-05-07 19:35:11 -07:00
parent a44734b030
commit 6cb7494061
4 changed files with 231 additions and 101 deletions

2
go.mod
View File

@ -19,6 +19,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4 github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1
github.com/google/go-cmp v0.6.0 github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14 github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0 github.com/nlpodyssey/gopickle v0.3.0
@ -35,7 +36,6 @@ require (
github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect

View File

@ -1519,14 +1519,28 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 && tp.state != Done { if len(req.Tools) > 0 && tp.state != Done {
fmt.Println("checking tool calls") fmt.Println("checking tool calls")
/*
This should give us a few return things we shouldnt have to build things up.
1. tool calls if any
2. leftover tokens if any - this happens in the partial case where we have a prefix inside a string
3. if we need to skip this loop and not send anything back
between these three things, we should just be switching on either the state or something to capture this
potentially consider a difference between internal and external state
*/
toolCalls, leftover, ok := tp.ParseToolCalls(r.Content) toolCalls, leftover, ok := tp.ParseToolCalls(r.Content)
// todo: this should just be one check/state coming back from the parse tool calls
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) { if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
return return
} }
// todo: our second state also should just be a var
if tp.state == ContainsPartialPrefix { if tp.state == ContainsPartialPrefix {
fmt.Println("sending tokens", leftover) fmt.Println("sending tokens", leftover)
res.Message.Content = leftover res.Message.Content = leftover
} }
// TODO: this can be done inside the parse tool calls
if ok && len(toolCalls) > 0 { if ok && len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls res.Message.ToolCalls = toolCalls
for i := range toolCalls { for i := range toolCalls {

View File

@ -20,7 +20,6 @@ const (
SendTokens State = iota SendTokens State = iota
GreedyToolWithPrefix GreedyToolWithPrefix
GreedyToolNoPrefix GreedyToolNoPrefix
// ToolCall
ForceTools ForceTools
ToolSuffix ToolSuffix
ContainsPartialPrefix ContainsPartialPrefix
@ -56,7 +55,7 @@ type ToolParser struct {
} }
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// Returns parsed tool calls and a boolean indicating if the JSON is incomplete // Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
var b bytes.Buffer var b bytes.Buffer
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{ if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
@ -93,6 +92,9 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
for _, v := range o { for _, v := range o {
all = append(all, collect(v)...) all = append(all, collect(v)...)
} }
default:
// TODO: err or fallback
return nil
} }
return all return all
@ -241,9 +243,37 @@ func (p *ToolParser) updateInputState(s string, hasPrefix bool) (string, bool) {
return s, true return s, true
} }
func (p *ToolParser) sendTokens(original string, hasPrefix bool) (string, bool) {
if p.state == SendTokens {
return "", false
}
if p.state == ContainsPartialPrefix {
idx := strings.Index(original, p.toolPrefix)
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
return original[:idx], false
} else {
fmt.Println("some weird state")
}
} else if strings.HasSuffix(original, p.toolPrefix[2:]) {
// can be with string or just the token
if hasPrefix {
original = strings.TrimSpace(original[:len(original)-(len(p.toolPrefix)+1)])
return original, false
} else {
p.state = ToolSuffix
p.sb.Reset()
}
}
return "", true
}
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. // ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
// Returns tool calls, whether parsing is incomplete, and any errors. // Returns tool calls, whether parsing is incomplete, and any errors.
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) { func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) {
original := s
// append input // append input
p.sb.WriteString(s) p.sb.WriteString(s)
s = p.sb.String() s = p.sb.String()
@ -258,7 +288,20 @@ func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) {
s, ok := p.updateInputState(s, hasPrefix) s, ok := p.updateInputState(s, hasPrefix)
if !ok { if !ok {
if p.state == ContainsPartialPrefix { if p.state == ContainsPartialPrefix {
return nil, s, false idx := strings.Index(original, p.toolPrefix)
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
// p.sb.Reset()
// p.sb.WriteString(original[idx:])
return nil, original[:idx], false
} else {
fmt.Println("some weird state")
}
// }
// s, ok = p.sendTokens(original, hasPrefix)
// if ok {
// return nil, s, true
} }
return nil, "", false return nil, "", false
} }
@ -292,7 +335,7 @@ func NewToolParser(model *Model) *ToolParser {
} else { } else {
state = GreedyToolWithPrefix state = GreedyToolWithPrefix
} }
fmt.Println("state", state) fmt.Println("setup state", state)
return &ToolParser{ return &ToolParser{
tmpl: tmpl, tmpl: tmpl,
sb: &strings.Builder{}, sb: &strings.Builder{},

View File

@ -48,40 +48,45 @@ func TestParseToolCalls(t *testing.T) {
} }
cases := []struct { cases := []struct {
name string name string
model string model string
output string output string
expected []api.ToolCall expectedToolCall []api.ToolCall
wantErr bool expectedTokens string
wantErr bool
}{ }{
{ {
name: "mistral invalid json", name: "mistral invalid json",
model: "mistral", model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
expected: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
wantErr: true, expectedTokens: "",
wantErr: true,
}, },
{ {
name: "mistral multiple tool calls - no prefix", name: "mistral multiple tool calls - no prefix",
model: "mistral", model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "mistral tool calls with text in between - no prefix", name: "mistral tool calls with text in between - no prefix",
model: "mistral", model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
wantErr: false, expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: false,
}, },
{ {
name: "mistral valid json - with prefix", name: "mistral valid json - with prefix",
model: "mistral", model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
// In this case we'd be ignoring the text in between and just returning the tool calls // In this case we'd be ignoring the text in between and just returning the tool calls
@ -89,15 +94,17 @@ func TestParseToolCalls(t *testing.T) {
model: "mistral", model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2, t1, t2}, expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "mistral incomplete json", name: "mistral incomplete json",
model: "mistral", model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
expected: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
wantErr: true, expectedTokens: "",
wantErr: true,
}, },
{ {
name: "mistral without tool token", name: "mistral without tool token",
@ -105,15 +112,17 @@ func TestParseToolCalls(t *testing.T) {
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
wantErr: true, expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: true,
}, },
{ {
name: "mistral without tool token - tool first", name: "mistral without tool token - tool first",
model: "mistral", model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "command-r-plus with json block", name: "command-r-plus with json block",
@ -136,15 +145,17 @@ func TestParseToolCalls(t *testing.T) {
} }
] ]
` + "```", ` + "```",
expected: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "firefunction with functools", name: "firefunction with functools",
model: "firefunction", model: "firefunction",
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "llama3 with tool call tags", name: "llama3 with tool call tags",
@ -152,64 +163,106 @@ func TestParseToolCalls(t *testing.T) {
output: `<tool_call> output: `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
</tool_call>`, </tool_call>`,
expected: []api.ToolCall{t1}, expectedToolCall: []api.ToolCall{t1},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "xlam with tool_calls wrapper", name: "xlam with tool_calls wrapper",
model: "xlam", model: "xlam",
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
expected: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "qwen with single tool call", name: "qwen2.5 with single tool call",
model: "qwen2.5-coder", model: "qwen2.5-coder",
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`, output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expected: []api.ToolCall{t1}, expectedToolCall: []api.ToolCall{t1},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "qwen with invalid tool token", name: "qwen with invalid tool token",
model: "qwen2.5-coder", model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expected: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
wantErr: false, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "qwen3 with single tool call and thinking", // tests the leftover logic as well
model: "qwen3", name: "qwen3 with single tool call and thinking",
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`, model: "qwen3",
expected: []api.ToolCall{t1}, output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
wantErr: false, expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
}, },
{ {
name: "qwen3 with single tool call and thinking spaces", name: "qwen3 with single tool call and thinking spaces",
model: "qwen3", model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expected: []api.ToolCall{t1}, expectedToolCall: []api.ToolCall{t1},
wantErr: false, expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
},
// {
// name: "qwen3 testing",
// model: "qwen3",
// output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{},
// expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// wantErr: true,
// },
// {
// name: "qwen3 testing 2",
// model: "qwen3",
// output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{t1},
// expectedTokens: `<think></think>`,
// wantErr: true,
// },
{
name: "qwen with no tool calls",
model: "qwen2.5-coder",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
wantErr: true,
}, },
{ {
name: "qwen with no tool calls", name: "llama3.2 with tool call - no prefix",
model: "qwen2.5-coder", model: "llama3.2",
output: " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expected: []api.ToolCall{}, expectedToolCall: []api.ToolCall{t1},
wantErr: true, expectedTokens: "",
wantErr: false,
}, },
{ {
name: "llama3.2 with tool call - no prefix", name: "llama3.2 with incomplete tool call - no prefix",
model: "llama3.2", model: "llama3.2",
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
expected: []api.ToolCall{t1}, expectedToolCall: []api.ToolCall{},
wantErr: false, expectedTokens: "",
wantErr: true,
}, },
{ {
name: "llama3.2 with tool call - in middle", name: "llama3.2 with tool call - in middle",
model: "llama3.2", model: "llama3.2",
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expected: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
wantErr: true, expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
},
{
name: "llama3.2 - fake tool prefix",
model: "llama3.2",
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
}, },
} }
@ -232,12 +285,10 @@ func TestParseToolCalls(t *testing.T) {
t.Run("template", func(t *testing.T) { t.Run("template", func(t *testing.T) {
actual := &bytes.Buffer{} // Create new buffer for each test actual := &bytes.Buffer{} // Create new buffer for each test
t.Log("template", tmpl, "model", tt.model)
if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil { if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log("actual", actual.String())
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@ -248,31 +299,53 @@ func TestParseToolCalls(t *testing.T) {
tp := NewToolParser(m) tp := NewToolParser(m)
got := []api.ToolCall{} got := []api.ToolCall{}
success := false success := false
var actualTokens strings.Builder
tokens := strings.Fields(tt.output) tokens := strings.Fields(tt.output)
for _, tok := range tokens { for _, tok := range tokens {
add := true
s := " " + tok s := " " + tok
var toolCalls []api.ToolCall
var ok bool // TODO(parthsareen): This logic is brittle as it mocks the logic in route, however can
if tp.state != Done { if tp.state != Done {
toolCalls, _, ok = tp.ParseToolCalls(s) toolCalls, leftover, ok := tp.ParseToolCalls(s)
if ok { if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
success = true continue
} }
got = append(got, toolCalls...) if tp.state == ContainsPartialPrefix {
// actualTokens.Reset()
actualTokens.WriteString(leftover)
t.Log("leftover", leftover)
add = false
// continue
}
if ok && len(toolCalls) > 0 {
success = true
got = append(got, toolCalls...)
add = false
// actualTokens.Reset()
}
}
// s = strings.TrimSpace(s)
if add {
actualTokens.WriteString(s)
} }
} }
if !tt.wantErr { if !tt.wantErr {
if diff := cmp.Diff(got, tt.expected); diff != "" { if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
} }
if !success && !tt.wantErr { if !success && !tt.wantErr {
t.Errorf("expected success but got errors") t.Errorf("expected success but got errors")
} }
stripped := strings.TrimSpace(actualTokens.String())
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
}
}) })
}) })
} }
} }
// TODO: add tests to check string sent not just tool