This commit is contained in:
ParthSareen 2025-05-06 18:29:06 -07:00
parent 516a540df7
commit b5a982ecb0
3 changed files with 200 additions and 73 deletions

View File

@ -1517,13 +1517,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if len(req.Tools) > 0 && !tp.done {
if len(req.Tools) > 0 && tp.state != Done {
fmt.Println("checking tool calls")
toolCalls, ok := tp.ParseToolCalls(r.Content)
if tp.state == PartialTool {
fmt.Println("partial tool, returning")
toolCalls, leftover, ok := tp.ParseToolCalls(r.Content)
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
return
}
if tp.state == ContainsPartialPrefix {
fmt.Println("sending tokens", leftover)
res.Message.Content = leftover
}
if ok && len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
@ -1535,6 +1538,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}
fmt.Println("sending response", res.Message.Content)
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}

View File

@ -17,17 +17,42 @@ import (
type State int
const (
NoTool State = iota
PartialTool
ToolCall
SendTokens State = iota
GreedyToolWithPrefix
GreedyToolNoPrefix
// ToolCall
ForceTools
ToolSuffix
ContainsPartialPrefix
Done
)
func (s State) String() string {
switch s {
case SendTokens:
return "SendTokens"
case GreedyToolWithPrefix:
return "GreedyToolWithPrefix"
case GreedyToolNoPrefix:
return "GreedyToolNoPrefix"
case ForceTools:
return "ForceTools"
case ToolSuffix:
return "ToolSuffix"
case Done:
return "Done"
case ContainsPartialPrefix:
return "PartialPrefix"
default:
return fmt.Sprintf("Unknown State (%d)", s)
}
}
type ToolParser struct {
tmpl *gotmpl.Template
state State
sb *strings.Builder
toolPrefix string
done bool
}
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
@ -49,8 +74,6 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
return nil, false, false
}
// slog.Debug("template", "template", b.String())
// ! this can be either a map or an array
var temp any
err := jsonv2.Unmarshal(b.Bytes(), &temp)
@ -88,7 +111,6 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
if len(templateObjects) == 0 {
return nil, false, false
}
// fmt.Println("template objects", templateObjects)
// find the keys that correspond to the name and arguments fields
var name, arguments string
@ -142,81 +164,134 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
}
slog.Debug("parsed tool calls", "count", len(toolCalls))
return toolCalls, len(toolCalls) > 0, true
return toolCalls, false, true
}
func (p *ToolParser) updateState(ok bool, partial bool, tcs []api.ToolCall) {
switch {
case !ok && !partial && p.state == ForceTools:
fmt.Println("Case: !ok && !partial && ForceTools - staying in force tools, resetting buffer")
// force partial tool if we have a prefix
// no op and stay in force tools
p.sb.Reset()
case !ok && !partial:
fmt.Println("Case: !ok && !partial")
fmt.Println("state", p.state)
if p.state == GreedyToolNoPrefix {
fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done")
p.state = Done
}
if p.state == GreedyToolWithPrefix {
fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens")
p.state = SendTokens
}
p.sb.Reset()
case !ok && partial:
fmt.Println("Case: !ok && partial - accumulating partial content")
// ! acucumulate
case len(tcs) > 0:
fmt.Println("Case: tool calls found")
// do not parse again in the greedy JSON case as soon as we have a tool call
if p.state == GreedyToolWithPrefix {
p.state = SendTokens
} else if p.state == GreedyToolNoPrefix {
fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens")
p.state = Done
}
p.sb.Reset()
}
}
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
// Returns tool calls, whether parsing is incomplete, and any errors.
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, bool) {
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) {
p.sb.WriteString(s)
s = p.sb.String()
s = strings.TrimSpace(s)
slog.Debug("parse tool calls", "content", s)
if len(s) == 0 {
return nil, false
return nil, "", false
}
hasPrefix := false
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
fmt.Println("hasPrefix", hasPrefix)
var tcs []api.ToolCall
var partial bool
var ok bool
if p.toolPrefix != "" {
if strings.HasPrefix(s, p.toolPrefix) {
s = strings.TrimSpace(s[len(p.toolPrefix):])
slog.Debug("tool prefix", "prefix", p.toolPrefix, "content", s)
p.state = PartialTool
hasPrefix = true
// Special token end case
} else if strings.HasSuffix(s, p.toolPrefix[2:]) {
p.state = PartialTool
p.sb.Reset()
if hasPrefix {
p.state = ForceTools
slog.Debug("tool prefix in prefix", "prefix", p.toolPrefix, "content", s)
// partial tool possibly
} else if strings.HasPrefix(p.toolPrefix, s) {
slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s)
// TODO: could possibly err maybe this should be greedy instead?
p.state = ForceTools
return nil, "", false
} else if strings.Contains(s, p.toolPrefix) {
idx := strings.Index(s, p.toolPrefix)
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
p.sb.Reset()
p.sb.WriteString(s[idx:])
return nil, s[:idx], false
}
}
// Special token end case
// if s, ok := strings.CutSuffix(s, p.toolPrefix[2:]); ok {
if strings.HasSuffix(s, p.toolPrefix[2:]) {
// can be with string or just the token
if hasPrefix {
s = strings.TrimSpace(s[:len(s)-(len(p.toolPrefix)+1)])
} else {
p.state = ToolSuffix
p.sb.Reset()
return nil, "", false
}
slog.Debug("setting to no tool", "content", s)
return nil, false
}
}
tcs, partial, ok := p.parseJSONToolCalls(s)
// TODO: figure out how to return the remaining string if not partial anymore
// update state
switch {
case !ok && !partial && hasPrefix:
p.state = PartialTool
case !ok && !partial:
p.state = NoTool
case !ok && partial:
p.state = PartialTool
case len(tcs) > 0:
p.state = ToolCall
fmt.Println("s before parsing", s)
if p.state == SendTokens {
fmt.Println("returning nil cause of send tokens")
return nil, "", false
}
if p.state == NoTool || p.state == ToolCall {
slog.Debug("resetting string builder", "state", p.state)
p.sb.Reset()
}
if !ok {
return nil, false
}
tcs, partial, ok = p.parseJSONToolCalls(s)
slog.Debug("returning tool calls", "tool calls", tcs)
fmt.Println("end state", p.state)
if p.toolPrefix == "" {
p.done = true
}
fmt.Println("len tcs", len(tcs))
return tcs, true
p.updateState(ok, partial, tcs)
if !ok {
return nil, "", false
}
return tcs, "", true
}
func NewToolParser(model *Model) *ToolParser {
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
templateToolPrefix = strings.TrimSpace(templateToolPrefix)
slog.Debug("tool prefix", "prefix", templateToolPrefix)
tmpl, ok := ToolTemplate(model)
if !ok {
return nil
}
var state State
if templateToolPrefix == "" {
state = GreedyToolNoPrefix
} else {
state = GreedyToolWithPrefix
}
return &ToolParser{
tmpl: tmpl,
sb: &strings.Builder{},
toolPrefix: templateToolPrefix,
done: false,
state: state,
}
}

View File

@ -51,7 +51,7 @@ func TestParseToolCalls(t *testing.T) {
name string
model string
output string
token string
prefix string
expected []api.ToolCall
wantErr bool
}{
@ -59,23 +59,50 @@ func TestParseToolCalls(t *testing.T) {
name: "mistral invalid json",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
token: "[TOOL_CALLS]",
prefix: "[TOOL_CALLS]",
expected: []api.ToolCall{},
wantErr: true,
},
{
name: "mistral valid json",
name: "mistral multiple tool calls - no prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
prefix: "[TOOL_CALLS]",
expected: []api.ToolCall{t1, t2},
wantErr: false,
},
{
name: "mistral tool calls with text in between - no prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
prefix: "[TOOL_CALLS]",
expected: []api.ToolCall{t1, t2},
wantErr: false,
},
{
name: "mistral valid json - with prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
token: "[TOOL_CALLS]",
prefix: "[TOOL_CALLS]",
expected: []api.ToolCall{t1, t2},
wantErr: false,
},
{
// In this case we'd be ignoring the text in between and just returning the tool calls
name: "mistral valid json with text in between - with prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
prefix: "[TOOL_CALLS]",
expected: []api.ToolCall{t1, t2, t1, t2},
wantErr: false,
},
{
name: "mistral incomplete json",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
token: "[TOOL_CALLS]",
prefix: "[TOOL_CALLS]",
expected: []api.ToolCall{},
wantErr: true,
},
@ -85,7 +112,7 @@ func TestParseToolCalls(t *testing.T) {
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
token: "[TOOL_CALLS]",
prefix: "[TOOL_CALLS]",
expected: []api.ToolCall{},
wantErr: true,
},
@ -93,7 +120,7 @@ func TestParseToolCalls(t *testing.T) {
name: "mistral without tool token - tool first",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
token: "[TOOL_CALLS]",
prefix: "[TOOL_CALLS]",
expected: []api.ToolCall{t1, t2},
wantErr: false,
},
@ -118,7 +145,7 @@ func TestParseToolCalls(t *testing.T) {
}
]
` + "```",
token: "Action:",
prefix: "Action: ```json",
expected: []api.ToolCall{t1, t2},
wantErr: false,
},
@ -126,7 +153,7 @@ func TestParseToolCalls(t *testing.T) {
name: "firefunction with functools",
model: "firefunction",
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
token: "functools",
prefix: "functools",
expected: []api.ToolCall{t1, t2},
wantErr: false,
},
@ -136,7 +163,7 @@ func TestParseToolCalls(t *testing.T) {
output: `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
</tool_call>`,
token: "<tool_call>",
prefix: "<tool_call>",
expected: []api.ToolCall{t1},
wantErr: false,
},
@ -144,16 +171,15 @@ func TestParseToolCalls(t *testing.T) {
name: "xlam with tool_calls wrapper",
model: "xlam",
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
token: "",
prefix: "",
expected: []api.ToolCall{t1, t2},
wantErr: false,
},
{
// TODO: fix the spacing issue
name: "qwen with single tool call",
model: "qwen2.5-coder",
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
token: "<tool_call>",
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
prefix: "<tool_call>",
expected: []api.ToolCall{t1},
wantErr: false,
},
@ -161,15 +187,31 @@ func TestParseToolCalls(t *testing.T) {
name: "qwen with invalid tool token",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
token: "[TOOL_CALLS]",
prefix: "[TOOL_CALLS]",
expected: []api.ToolCall{t1, t2},
wantErr: false,
},
{
name: "qwen3 with single tool call and thinking",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
prefix: "<tool_call>",
expected: []api.ToolCall{t1},
wantErr: false,
},
{
name: "qwen3 with single tool call and thinking spaces",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
prefix: "<tool_call>",
expected: []api.ToolCall{t1},
wantErr: false,
},
{
name: "qwen with no tool calls",
model: "qwen2.5-coder",
output: " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
token: "",
prefix: "",
expected: []api.ToolCall{},
wantErr: true,
},
@ -211,11 +253,15 @@ func TestParseToolCalls(t *testing.T) {
tokens := strings.Fields(tt.output)
for _, tok := range tokens {
s := " " + tok
toolCalls, ok := tp.ParseToolCalls(s)
if ok {
success = true
var toolCalls []api.ToolCall
var ok bool
if tp.state != Done {
toolCalls, _, ok = tp.ParseToolCalls(s)
if ok {
success = true
}
got = append(got, toolCalls...)
}
got = append(got, toolCalls...)
}
if !tt.wantErr {
@ -230,3 +276,5 @@ func TestParseToolCalls(t *testing.T) {
})
}
}
// TODO: add tests to check string sent not just tool