checkpoint - cleanup still left, functionality setup

This commit is contained in:
ParthSareen 2025-05-08 18:48:44 -07:00
parent 6cb7494061
commit 779547fcde
3 changed files with 177 additions and 183 deletions

View File

@ -1486,10 +1486,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
go func() { go func() {
defer close(ch) defer close(ch)
// var sb strings.Builder // var sb strings.Builder
var toolCallIndex int = 0 var toolParser *ToolParser
var tp *ToolParser
if len(req.Tools) > 0 { if len(req.Tools) > 0 {
tp = NewToolParser(m) toolParser = NewToolParser(m)
} }
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
@ -1517,37 +1516,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} }
if len(req.Tools) > 0 && tp.state != Done { if len(req.Tools) > 0 && !toolParser.Done {
fmt.Println("checking tool calls") toolCalls, leftover := toolParser.ParseToolCalls(r.Content)
/* switch toolParser.ParserState {
This should give us a few return things we shouldnt have to build things up. case ToolCallAccumulate:
1. tool calls if any // tokens are accumulated in the tool parser
2. leftover tokens if any - this happens in the partial case where we have a prefix inside a string
3. if we need to skip this loop and not send anything back
between these three things, we should just be switching on either the state or something to capture this
potentially consider a difference between internal and external state
*/
toolCalls, leftover, ok := tp.ParseToolCalls(r.Content)
// todo: this should just be one check/state coming back from the parse tool calls
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
return return
} case ToolCallSendTokens:
// todo: our second state also should just be a var // tokens are sent back in the response
if tp.state == ContainsPartialPrefix { case ToolCallSendPartial:
fmt.Println("sending tokens", leftover) // tokens not needed for parsing are sent back in the response
res.Message.Content = leftover res.Message.Content = leftover
} case ToolCallFound:
// TODO: this can be done inside the parse tool calls
if ok && len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
// Remove content when tool call is present
res.Message.Content = "" res.Message.Content = ""
} }
} }

View File

@ -10,12 +10,14 @@ import (
gotmpl "text/template" gotmpl "text/template"
jsonv2 "github.com/go-json-experiment/json" jsonv2 "github.com/go-json-experiment/json"
jsontext "github.com/go-json-experiment/json/jsontext"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
type State int type State int
// TODO: potentially coalesce states
const ( const (
SendTokens State = iota SendTokens State = iota
GreedyToolWithPrefix GreedyToolWithPrefix
@ -26,6 +28,30 @@ const (
Done Done
) )
type ExternalState int
const (
ToolCallFound ExternalState = iota
ToolCallSendPartial
ToolCallAccumulate
ToolCallSendTokens
)
func (s ExternalState) String() string {
switch s {
case ToolCallFound:
return "ToolCallFound"
case ToolCallSendPartial:
return "ToolCallSendPartial"
case ToolCallAccumulate:
return "ToolCallAccumulate"
case ToolCallSendTokens:
return "ToolCallSendTokens"
default:
return fmt.Sprintf("Unknown ExternalState (%d)", s)
}
}
func (s State) String() string { func (s State) String() string {
switch s { switch s {
case SendTokens: case SendTokens:
@ -47,13 +73,18 @@ func (s State) String() string {
} }
} }
// TODO: simplify if possible
type ToolParser struct { type ToolParser struct {
tmpl *gotmpl.Template tmpl *gotmpl.Template
state State state State
sb *strings.Builder sb *strings.Builder
toolPrefix string toolPrefix string
toolIndex int
ParserState ExternalState
Done bool
} }
// ? move to a separate file
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found // Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
@ -73,7 +104,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
return nil, false, false return nil, false, false
} }
// ! this can be either a map or an array // this can be either a map or an array
var temp any var temp any
err := jsonv2.Unmarshal(b.Bytes(), &temp) err := jsonv2.Unmarshal(b.Bytes(), &temp)
if err != nil { if err != nil {
@ -128,6 +159,14 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
if name == "" || arguments == "" { if name == "" || arguments == "" {
return nil, false, false return nil, false, false
} }
// TODO: there is probably some underlying repeat work here to avoid
// This incrementally decodes the JSON string and returns the first parsedobject
dec := jsontext.NewDecoder(strings.NewReader(s))
if got, err := dec.ReadValue(); err == nil {
s = got.String()
}
var responseObjects any var responseObjects any
err = jsonv2.Unmarshal([]byte(s), &responseObjects) err = jsonv2.Unmarshal([]byte(s), &responseObjects)
if err != nil { if err != nil {
@ -137,7 +176,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
return nil, true, false return nil, true, false
} else { } else {
fmt.Printf("Other error: %v\n", err) fmt.Printf("Other error: %v\n", err)
fmt.Println("exiting", p.state) fmt.Println("exiting from JSON parsing", p.state)
return nil, false, false return nil, false, false
} }
} }
@ -182,10 +221,14 @@ func (p *ToolParser) updateOutputState(ok bool, partial bool, tcs []api.ToolCall
if p.state == GreedyToolNoPrefix { if p.state == GreedyToolNoPrefix {
fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done") fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done")
p.state = Done p.state = Done
// p.ParserState = DoneFR
p.ParserState = ToolCallSendTokens
p.Done = true
} }
if p.state == GreedyToolWithPrefix { if p.state == GreedyToolWithPrefix {
fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens") fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens")
p.state = SendTokens p.state = SendTokens
p.ParserState = ToolCallSendTokens
} }
p.sb.Reset() p.sb.Reset()
case !ok && partial: case !ok && partial:
@ -197,131 +240,118 @@ func (p *ToolParser) updateOutputState(ok bool, partial bool, tcs []api.ToolCall
// do not parse again in the greedy JSON case as soon as we have a tool call // do not parse again in the greedy JSON case as soon as we have a tool call
if p.state == GreedyToolWithPrefix { if p.state == GreedyToolWithPrefix {
p.state = SendTokens p.state = SendTokens
p.ParserState = ToolCallFound
p.state = Done
p.Done = true
} else if p.state == GreedyToolNoPrefix { } else if p.state == GreedyToolNoPrefix {
fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens") fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens")
p.state = Done p.state = Done
p.Done = true
} }
p.sb.Reset() p.sb.Reset()
} }
p.updateExternalState(tcs)
} }
func (p *ToolParser) updateInputState(s string, hasPrefix bool) (string, bool) { func (p *ToolParser) updateExternalState(tcs []api.ToolCall) {
if (p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix || p.state == ToolSuffix) || (p.state == ForceTools && len(tcs) == 0) {
p.ParserState = ToolCallAccumulate
} else if p.state == ContainsPartialPrefix {
p.ParserState = ToolCallSendPartial
} else if len(tcs) > 0 {
p.ParserState = ToolCallFound
} else if p.state == SendTokens {
p.ParserState = ToolCallSendTokens
}
}
// string, and if it has a prefix
func (p *ToolParser) checkPrefix(s string) (string, bool) {
if p.toolPrefix == "" { if p.toolPrefix == "" {
return s, true return s, true
} }
original := s
// s = strings.TrimSpace(s)
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
if hasPrefix { if hasPrefix {
fmt.Println("has prefix", s)
p.state = ForceTools p.state = ForceTools
// partial tool possibly // partial tool possibly
} else if strings.HasPrefix(p.toolPrefix, s) { } else if strings.HasPrefix(p.toolPrefix, s) {
slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s) slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s)
// TODO: could possibly err maybe this should be greedy instead? // TODO: could possibly err maybe this should be greedy instead?
p.state = ForceTools p.state = ForceTools
// this would basically be a no op on rest of the input
return "", false return "", false
} else if strings.Contains(s, p.toolPrefix) { // the case where "token<tool_call>" - send "token" back
idx := strings.Index(s, p.toolPrefix) // accounts for spaces in prefix or suffix to avoid breaking cache
if idx != -1 { } else if strings.Contains(original, p.toolPrefix) {
// still keeps the prefix
p.state = ContainsPartialPrefix
p.sb.Reset()
p.sb.WriteString(s[idx:])
return s[:idx], false
}
}
// Special token end case
if strings.HasSuffix(s, p.toolPrefix[2:]) {
// can be with string or just the token
if hasPrefix {
s = strings.TrimSpace(s[:len(s)-(len(p.toolPrefix)+1)])
} else {
p.state = ToolSuffix
p.sb.Reset()
return "", false
}
slog.Debug("setting to no tool", "content", s)
}
return s, true
}
func (p *ToolParser) sendTokens(original string, hasPrefix bool) (string, bool) {
if p.state == SendTokens {
return "", false
}
if p.state == ContainsPartialPrefix {
idx := strings.Index(original, p.toolPrefix) idx := strings.Index(original, p.toolPrefix)
if idx != -1 { if idx != -1 {
// still keeps the prefix // still keeps the prefix
p.state = ContainsPartialPrefix p.state = ContainsPartialPrefix
return original[:idx], false
} else {
fmt.Println("some weird state")
}
} else if strings.HasSuffix(original, p.toolPrefix[2:]) {
// can be with string or just the token
if hasPrefix {
original = strings.TrimSpace(original[:len(original)-(len(p.toolPrefix)+1)])
return original, false
} else {
p.state = ToolSuffix
p.sb.Reset() p.sb.Reset()
// todo: see if there is a simpler way for this
idx2 := strings.Index(s, p.toolPrefix)
p.sb.WriteString(s[idx2:])
return original[:idx], false
} }
} }
return "", true return s, true
} }
// TODO: simplify the flow of this function
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. // ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
// Returns tool calls, whether parsing is incomplete, and any errors. // Returns tool calls, whether parsing is incomplete, and any errors.
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) { func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string) {
original := s fmt.Println("checking tool calls", s)
// append input fmt.Println("external state", p.ParserState)
fmt.Println("internal state", p.state)
p.sb.WriteString(s) p.sb.WriteString(s)
s = p.sb.String() s = p.sb.String()
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
fmt.Println("sb", s)
p.updateExternalState(nil)
if len(s) == 0 { if len(s) == 0 {
return nil, "", false return nil, ""
} }
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix) s, cont := p.checkPrefix(s)
if !cont {
s, ok := p.updateInputState(s, hasPrefix) p.updateExternalState(nil)
if !ok {
if p.state == ContainsPartialPrefix { if p.state == ContainsPartialPrefix {
idx := strings.Index(original, p.toolPrefix) return nil, s
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
// p.sb.Reset()
// p.sb.WriteString(original[idx:])
return nil, original[:idx], false
} else {
fmt.Println("some weird state")
}
// }
// s, ok = p.sendTokens(original, hasPrefix)
// if ok {
// return nil, s, true
} }
return nil, "", false return nil, ""
} }
// stay in SendTokens unless we have a prefix
if p.state == SendTokens { if p.state == SendTokens {
return nil, "", false fmt.Println("SendTokens - resetting buffer")
p.updateExternalState(nil)
p.sb.Reset()
return nil, ""
} }
var tcs []api.ToolCall tcs, partial, ok := p.parseJSONToolCalls(s)
var partial bool
tcs, partial, ok = p.parseJSONToolCalls(s)
p.updateOutputState(ok, partial, tcs) p.updateOutputState(ok, partial, tcs)
fmt.Println("output state", p.ParserState, p.state)
if !ok { if !ok {
return nil, "", false fmt.Println("returning empty tool calls")
return nil, ""
} }
for _, tc := range tcs {
return tcs, "", true tc.Function.Index = p.toolIndex
p.toolIndex++
}
return tcs, ""
} }
func NewToolParser(model *Model) *ToolParser { func NewToolParser(model *Model) *ToolParser {
// TODO: use new template parsing to get all tokens for the prefix
templateToolPrefix, _ := ToolPrefix(model.Template.Template) templateToolPrefix, _ := ToolPrefix(model.Template.Template)
templateToolPrefix = strings.TrimSpace(templateToolPrefix) templateToolPrefix = strings.TrimSpace(templateToolPrefix)
tmpl, ok := ToolTemplate(model) tmpl, ok := ToolTemplate(model)
@ -337,9 +367,10 @@ func NewToolParser(model *Model) *ToolParser {
} }
fmt.Println("setup state", state) fmt.Println("setup state", state)
return &ToolParser{ return &ToolParser{
tmpl: tmpl, tmpl: tmpl,
sb: &strings.Builder{}, sb: &strings.Builder{},
toolPrefix: templateToolPrefix, toolPrefix: templateToolPrefix,
state: state, state: state,
ParserState: ToolCallAccumulate,
} }
} }

View File

@ -53,7 +53,6 @@ func TestParseToolCalls(t *testing.T) {
output string output string
expectedToolCall []api.ToolCall expectedToolCall []api.ToolCall
expectedTokens string expectedTokens string
wantErr bool
}{ }{
{ {
name: "mistral invalid json", name: "mistral invalid json",
@ -61,7 +60,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
expectedToolCall: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
expectedTokens: "", expectedTokens: "",
wantErr: true,
}, },
{ {
name: "mistral multiple tool calls - no prefix", name: "mistral multiple tool calls - no prefix",
@ -69,7 +67,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
name: "mistral tool calls with text in between - no prefix", name: "mistral tool calls with text in between - no prefix",
@ -78,7 +75,6 @@ func TestParseToolCalls(t *testing.T) {
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: false,
}, },
{ {
name: "mistral valid json - with prefix", name: "mistral valid json - with prefix",
@ -86,7 +82,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
// In this case we'd be ignoring the text in between and just returning the tool calls // In this case we'd be ignoring the text in between and just returning the tool calls
@ -96,7 +91,6 @@ func TestParseToolCalls(t *testing.T) {
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2, t1, t2}, expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
name: "mistral incomplete json", name: "mistral incomplete json",
@ -104,7 +98,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
expectedTokens: "", expectedTokens: "",
wantErr: true,
}, },
{ {
name: "mistral without tool token", name: "mistral without tool token",
@ -114,7 +107,6 @@ func TestParseToolCalls(t *testing.T) {
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: true,
}, },
{ {
name: "mistral without tool token - tool first", name: "mistral without tool token - tool first",
@ -122,7 +114,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
name: "command-r-plus with json block", name: "command-r-plus with json block",
@ -147,7 +138,6 @@ func TestParseToolCalls(t *testing.T) {
` + "```", ` + "```",
expectedToolCall: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
name: "firefunction with functools", name: "firefunction with functools",
@ -155,7 +145,6 @@ func TestParseToolCalls(t *testing.T) {
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
name: "llama3 with tool call tags", name: "llama3 with tool call tags",
@ -165,7 +154,6 @@ func TestParseToolCalls(t *testing.T) {
</tool_call>`, </tool_call>`,
expectedToolCall: []api.ToolCall{t1}, expectedToolCall: []api.ToolCall{t1},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
name: "xlam with tool_calls wrapper", name: "xlam with tool_calls wrapper",
@ -173,7 +161,6 @@ func TestParseToolCalls(t *testing.T) {
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
expectedToolCall: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
name: "qwen2.5 with single tool call", name: "qwen2.5 with single tool call",
@ -181,15 +168,34 @@ func TestParseToolCalls(t *testing.T) {
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`, output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1}, expectedToolCall: []api.ToolCall{t1},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
name: "qwen with invalid tool token", name: "qwen with no tool prefix",
model: "qwen2.5-coder", model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2}, expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "", expectedTokens: "",
wantErr: false, },
{
name: "qwen with no tool calls",
model: "qwen2.5-coder",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
},
{
name: "qwen with no tool prefix",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "some tokens after call",
},
{
name: "qwen with prefix",
model: "qwen2.5-coder",
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
}, },
{ {
// tests the leftover logic as well // tests the leftover logic as well
@ -198,7 +204,6 @@ func TestParseToolCalls(t *testing.T) {
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`, output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1}, expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>", expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
}, },
{ {
name: "qwen3 with single tool call and thinking spaces", name: "qwen3 with single tool call and thinking spaces",
@ -206,31 +211,20 @@ func TestParseToolCalls(t *testing.T) {
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1}, expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>", expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
}, },
// {
// name: "qwen3 testing",
// model: "qwen3",
// output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{},
// expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// wantErr: true,
// },
// {
// name: "qwen3 testing 2",
// model: "qwen3",
// output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{t1},
// expectedTokens: `<think></think>`,
// wantErr: true,
// },
{ {
name: "qwen with no tool calls", name: "qwen3 testing",
model: "qwen2.5-coder", model: "qwen3",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
wantErr: true, },
{
name: "qwen3 testing 2",
model: "qwen3",
output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: `<think></think>`,
}, },
{ {
name: "llama3.2 with tool call - no prefix", name: "llama3.2 with tool call - no prefix",
@ -238,7 +232,6 @@ func TestParseToolCalls(t *testing.T) {
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{t1}, expectedToolCall: []api.ToolCall{t1},
expectedTokens: "", expectedTokens: "",
wantErr: false,
}, },
{ {
name: "llama3.2 with incomplete tool call - no prefix", name: "llama3.2 with incomplete tool call - no prefix",
@ -246,7 +239,6 @@ func TestParseToolCalls(t *testing.T) {
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `, output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
expectedTokens: "", expectedTokens: "",
wantErr: true,
}, },
{ {
name: "llama3.2 with tool call - in middle", name: "llama3.2 with tool call - in middle",
@ -254,7 +246,6 @@ func TestParseToolCalls(t *testing.T) {
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
}, },
{ {
name: "llama3.2 - fake tool prefix", name: "llama3.2 - fake tool prefix",
@ -262,7 +253,6 @@ func TestParseToolCalls(t *testing.T) {
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{}, expectedToolCall: []api.ToolCall{},
expectedTokens: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, expectedTokens: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
}, },
} }
@ -298,7 +288,6 @@ func TestParseToolCalls(t *testing.T) {
m := &Model{Template: tmpl} m := &Model{Template: tmpl}
tp := NewToolParser(m) tp := NewToolParser(m)
got := []api.ToolCall{} got := []api.ToolCall{}
success := false
var actualTokens strings.Builder var actualTokens strings.Builder
tokens := strings.Fields(tt.output) tokens := strings.Fields(tt.output)
@ -306,40 +295,33 @@ func TestParseToolCalls(t *testing.T) {
add := true add := true
s := " " + tok s := " " + tok
// TODO(parthsareen): This logic is brittle as it mocks the logic in route, however can if !tp.Done {
if tp.state != Done { toolCalls, leftover := tp.ParseToolCalls(s)
toolCalls, leftover, ok := tp.ParseToolCalls(s) switch tp.ParserState {
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) { case ToolCallFound:
continue
}
if tp.state == ContainsPartialPrefix {
// actualTokens.Reset()
actualTokens.WriteString(leftover)
t.Log("leftover", leftover)
add = false
// continue
}
if ok && len(toolCalls) > 0 {
success = true
got = append(got, toolCalls...) got = append(got, toolCalls...)
add = false add = false
// actualTokens.Reset() case ToolCallSendTokens:
actualTokens.WriteString(s)
add = false
case ToolCallAccumulate:
add = false
case ToolCallSendPartial:
actualTokens.WriteString(" " + leftover)
add = false
} }
} }
// s = strings.TrimSpace(s)
if add { if add {
actualTokens.WriteString(s) actualTokens.WriteString(s)
} }
} }
if !tt.wantErr { // Compare tool calls if we expect any
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("tool calls mismatch (-got +want):\n%s", diff)
}
}
if !success && !tt.wantErr {
t.Errorf("expected success but got errors")
} }
// Compare tokens if we expect any
stripped := strings.TrimSpace(actualTokens.String()) stripped := strings.TrimSpace(actualTokens.String())
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" { if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens) t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)