wip
This commit is contained in:
parent
516a540df7
commit
b5a982ecb0
@ -1517,13 +1517,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 && !tp.done {
|
||||
if len(req.Tools) > 0 && tp.state != Done {
|
||||
fmt.Println("checking tool calls")
|
||||
toolCalls, ok := tp.ParseToolCalls(r.Content)
|
||||
if tp.state == PartialTool {
|
||||
fmt.Println("partial tool, returning")
|
||||
toolCalls, leftover, ok := tp.ParseToolCalls(r.Content)
|
||||
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
|
||||
return
|
||||
}
|
||||
if tp.state == ContainsPartialPrefix {
|
||||
fmt.Println("sending tokens", leftover)
|
||||
res.Message.Content = leftover
|
||||
}
|
||||
if ok && len(toolCalls) > 0 {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
for i := range toolCalls {
|
||||
@ -1535,6 +1538,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("sending response", res.Message.Content)
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
|
173
server/tools.go
173
server/tools.go
@ -17,17 +17,42 @@ import (
|
||||
type State int
|
||||
|
||||
const (
|
||||
NoTool State = iota
|
||||
PartialTool
|
||||
ToolCall
|
||||
SendTokens State = iota
|
||||
GreedyToolWithPrefix
|
||||
GreedyToolNoPrefix
|
||||
// ToolCall
|
||||
ForceTools
|
||||
ToolSuffix
|
||||
ContainsPartialPrefix
|
||||
Done
|
||||
)
|
||||
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case SendTokens:
|
||||
return "SendTokens"
|
||||
case GreedyToolWithPrefix:
|
||||
return "GreedyToolWithPrefix"
|
||||
case GreedyToolNoPrefix:
|
||||
return "GreedyToolNoPrefix"
|
||||
case ForceTools:
|
||||
return "ForceTools"
|
||||
case ToolSuffix:
|
||||
return "ToolSuffix"
|
||||
case Done:
|
||||
return "Done"
|
||||
case ContainsPartialPrefix:
|
||||
return "PartialPrefix"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown State (%d)", s)
|
||||
}
|
||||
}
|
||||
|
||||
type ToolParser struct {
|
||||
tmpl *gotmpl.Template
|
||||
state State
|
||||
sb *strings.Builder
|
||||
toolPrefix string
|
||||
done bool
|
||||
}
|
||||
|
||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
@ -49,8 +74,6 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// slog.Debug("template", "template", b.String())
|
||||
|
||||
// ! this can be either a map or an array
|
||||
var temp any
|
||||
err := jsonv2.Unmarshal(b.Bytes(), &temp)
|
||||
@ -88,7 +111,6 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||
if len(templateObjects) == 0 {
|
||||
return nil, false, false
|
||||
}
|
||||
// fmt.Println("template objects", templateObjects)
|
||||
|
||||
// find the keys that correspond to the name and arguments fields
|
||||
var name, arguments string
|
||||
@ -142,81 +164,134 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||
}
|
||||
|
||||
slog.Debug("parsed tool calls", "count", len(toolCalls))
|
||||
return toolCalls, len(toolCalls) > 0, true
|
||||
return toolCalls, false, true
|
||||
}
|
||||
|
||||
func (p *ToolParser) updateState(ok bool, partial bool, tcs []api.ToolCall) {
|
||||
switch {
|
||||
case !ok && !partial && p.state == ForceTools:
|
||||
fmt.Println("Case: !ok && !partial && ForceTools - staying in force tools, resetting buffer")
|
||||
// force partial tool if we have a prefix
|
||||
// no op and stay in force tools
|
||||
p.sb.Reset()
|
||||
case !ok && !partial:
|
||||
fmt.Println("Case: !ok && !partial")
|
||||
fmt.Println("state", p.state)
|
||||
if p.state == GreedyToolNoPrefix {
|
||||
fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done")
|
||||
p.state = Done
|
||||
}
|
||||
if p.state == GreedyToolWithPrefix {
|
||||
fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens")
|
||||
p.state = SendTokens
|
||||
}
|
||||
p.sb.Reset()
|
||||
case !ok && partial:
|
||||
fmt.Println("Case: !ok && partial - accumulating partial content")
|
||||
|
||||
// ! acucumulate
|
||||
|
||||
case len(tcs) > 0:
|
||||
fmt.Println("Case: tool calls found")
|
||||
// do not parse again in the greedy JSON case as soon as we have a tool call
|
||||
if p.state == GreedyToolWithPrefix {
|
||||
p.state = SendTokens
|
||||
} else if p.state == GreedyToolNoPrefix {
|
||||
fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens")
|
||||
p.state = Done
|
||||
}
|
||||
p.sb.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
||||
// Returns tool calls, whether parsing is incomplete, and any errors.
|
||||
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) {
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
s = strings.TrimSpace(s)
|
||||
slog.Debug("parse tool calls", "content", s)
|
||||
|
||||
if len(s) == 0 {
|
||||
return nil, false
|
||||
return nil, "", false
|
||||
}
|
||||
hasPrefix := false
|
||||
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
|
||||
fmt.Println("hasPrefix", hasPrefix)
|
||||
var tcs []api.ToolCall
|
||||
var partial bool
|
||||
var ok bool
|
||||
|
||||
if p.toolPrefix != "" {
|
||||
if strings.HasPrefix(s, p.toolPrefix) {
|
||||
s = strings.TrimSpace(s[len(p.toolPrefix):])
|
||||
slog.Debug("tool prefix", "prefix", p.toolPrefix, "content", s)
|
||||
p.state = PartialTool
|
||||
hasPrefix = true
|
||||
// Special token end case
|
||||
} else if strings.HasSuffix(s, p.toolPrefix[2:]) {
|
||||
p.state = PartialTool
|
||||
p.sb.Reset()
|
||||
if hasPrefix {
|
||||
p.state = ForceTools
|
||||
slog.Debug("tool prefix in prefix", "prefix", p.toolPrefix, "content", s)
|
||||
// partial tool possibly
|
||||
} else if strings.HasPrefix(p.toolPrefix, s) {
|
||||
slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s)
|
||||
// TODO: could possibly err maybe this should be greedy instead?
|
||||
p.state = ForceTools
|
||||
return nil, "", false
|
||||
} else if strings.Contains(s, p.toolPrefix) {
|
||||
idx := strings.Index(s, p.toolPrefix)
|
||||
if idx != -1 {
|
||||
// still keeps the prefix
|
||||
p.state = ContainsPartialPrefix
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(s[idx:])
|
||||
return nil, s[:idx], false
|
||||
}
|
||||
}
|
||||
// Special token end case
|
||||
// if s, ok := strings.CutSuffix(s, p.toolPrefix[2:]); ok {
|
||||
if strings.HasSuffix(s, p.toolPrefix[2:]) {
|
||||
// can be with string or just the token
|
||||
if hasPrefix {
|
||||
s = strings.TrimSpace(s[:len(s)-(len(p.toolPrefix)+1)])
|
||||
} else {
|
||||
p.state = ToolSuffix
|
||||
p.sb.Reset()
|
||||
return nil, "", false
|
||||
}
|
||||
slog.Debug("setting to no tool", "content", s)
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
tcs, partial, ok := p.parseJSONToolCalls(s)
|
||||
|
||||
// TODO: figure out how to return the remaining string if not partial anymore
|
||||
// update state
|
||||
switch {
|
||||
case !ok && !partial && hasPrefix:
|
||||
p.state = PartialTool
|
||||
case !ok && !partial:
|
||||
p.state = NoTool
|
||||
case !ok && partial:
|
||||
p.state = PartialTool
|
||||
case len(tcs) > 0:
|
||||
p.state = ToolCall
|
||||
fmt.Println("s before parsing", s)
|
||||
if p.state == SendTokens {
|
||||
fmt.Println("returning nil cause of send tokens")
|
||||
return nil, "", false
|
||||
}
|
||||
|
||||
if p.state == NoTool || p.state == ToolCall {
|
||||
slog.Debug("resetting string builder", "state", p.state)
|
||||
p.sb.Reset()
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
tcs, partial, ok = p.parseJSONToolCalls(s)
|
||||
slog.Debug("returning tool calls", "tool calls", tcs)
|
||||
fmt.Println("end state", p.state)
|
||||
if p.toolPrefix == "" {
|
||||
p.done = true
|
||||
}
|
||||
|
||||
fmt.Println("len tcs", len(tcs))
|
||||
return tcs, true
|
||||
p.updateState(ok, partial, tcs)
|
||||
if !ok {
|
||||
return nil, "", false
|
||||
}
|
||||
|
||||
return tcs, "", true
|
||||
}
|
||||
|
||||
func NewToolParser(model *Model) *ToolParser {
|
||||
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
|
||||
templateToolPrefix = strings.TrimSpace(templateToolPrefix)
|
||||
slog.Debug("tool prefix", "prefix", templateToolPrefix)
|
||||
tmpl, ok := ToolTemplate(model)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var state State
|
||||
if templateToolPrefix == "" {
|
||||
state = GreedyToolNoPrefix
|
||||
} else {
|
||||
state = GreedyToolWithPrefix
|
||||
}
|
||||
return &ToolParser{
|
||||
tmpl: tmpl,
|
||||
sb: &strings.Builder{},
|
||||
toolPrefix: templateToolPrefix,
|
||||
done: false,
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ func TestParseToolCalls(t *testing.T) {
|
||||
name string
|
||||
model string
|
||||
output string
|
||||
token string
|
||||
prefix string
|
||||
expected []api.ToolCall
|
||||
wantErr bool
|
||||
}{
|
||||
@ -59,23 +59,50 @@ func TestParseToolCalls(t *testing.T) {
|
||||
name: "mistral invalid json",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
|
||||
token: "[TOOL_CALLS]",
|
||||
prefix: "[TOOL_CALLS]",
|
||||
expected: []api.ToolCall{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "mistral valid json",
|
||||
name: "mistral multiple tool calls - no prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
prefix: "[TOOL_CALLS]",
|
||||
expected: []api.ToolCall{t1, t2},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mistral tool calls with text in between - no prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
prefix: "[TOOL_CALLS]",
|
||||
expected: []api.ToolCall{t1, t2},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mistral valid json - with prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
token: "[TOOL_CALLS]",
|
||||
prefix: "[TOOL_CALLS]",
|
||||
expected: []api.ToolCall{t1, t2},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
// In this case we'd be ignoring the text in between and just returning the tool calls
|
||||
name: "mistral valid json with text in between - with prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
prefix: "[TOOL_CALLS]",
|
||||
expected: []api.ToolCall{t1, t2, t1, t2},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mistral incomplete json",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
|
||||
token: "[TOOL_CALLS]",
|
||||
prefix: "[TOOL_CALLS]",
|
||||
expected: []api.ToolCall{},
|
||||
wantErr: true,
|
||||
},
|
||||
@ -85,7 +112,7 @@ func TestParseToolCalls(t *testing.T) {
|
||||
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||
|
||||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
token: "[TOOL_CALLS]",
|
||||
prefix: "[TOOL_CALLS]",
|
||||
expected: []api.ToolCall{},
|
||||
wantErr: true,
|
||||
},
|
||||
@ -93,7 +120,7 @@ func TestParseToolCalls(t *testing.T) {
|
||||
name: "mistral without tool token - tool first",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
token: "[TOOL_CALLS]",
|
||||
prefix: "[TOOL_CALLS]",
|
||||
expected: []api.ToolCall{t1, t2},
|
||||
wantErr: false,
|
||||
},
|
||||
@ -118,7 +145,7 @@ func TestParseToolCalls(t *testing.T) {
|
||||
}
|
||||
]
|
||||
` + "```",
|
||||
token: "Action:",
|
||||
prefix: "Action: ```json",
|
||||
expected: []api.ToolCall{t1, t2},
|
||||
wantErr: false,
|
||||
},
|
||||
@ -126,7 +153,7 @@ func TestParseToolCalls(t *testing.T) {
|
||||
name: "firefunction with functools",
|
||||
model: "firefunction",
|
||||
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
token: "functools",
|
||||
prefix: "functools",
|
||||
expected: []api.ToolCall{t1, t2},
|
||||
wantErr: false,
|
||||
},
|
||||
@ -136,7 +163,7 @@ func TestParseToolCalls(t *testing.T) {
|
||||
output: `<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||
</tool_call>`,
|
||||
token: "<tool_call>",
|
||||
prefix: "<tool_call>",
|
||||
expected: []api.ToolCall{t1},
|
||||
wantErr: false,
|
||||
},
|
||||
@ -144,16 +171,15 @@ func TestParseToolCalls(t *testing.T) {
|
||||
name: "xlam with tool_calls wrapper",
|
||||
model: "xlam",
|
||||
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
|
||||
token: "",
|
||||
prefix: "",
|
||||
expected: []api.ToolCall{t1, t2},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
// TODO: fix the spacing issue
|
||||
name: "qwen with single tool call",
|
||||
model: "qwen2.5-coder",
|
||||
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
token: "<tool_call>",
|
||||
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||
prefix: "<tool_call>",
|
||||
expected: []api.ToolCall{t1},
|
||||
wantErr: false,
|
||||
},
|
||||
@ -161,15 +187,31 @@ func TestParseToolCalls(t *testing.T) {
|
||||
name: "qwen with invalid tool token",
|
||||
model: "qwen2.5-coder",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
token: "[TOOL_CALLS]",
|
||||
prefix: "[TOOL_CALLS]",
|
||||
expected: []api.ToolCall{t1, t2},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "qwen3 with single tool call and thinking",
|
||||
model: "qwen3",
|
||||
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||
prefix: "<tool_call>",
|
||||
expected: []api.ToolCall{t1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "qwen3 with single tool call and thinking spaces",
|
||||
model: "qwen3",
|
||||
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
prefix: "<tool_call>",
|
||||
expected: []api.ToolCall{t1},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "qwen with no tool calls",
|
||||
model: "qwen2.5-coder",
|
||||
output: " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||
token: "",
|
||||
prefix: "",
|
||||
expected: []api.ToolCall{},
|
||||
wantErr: true,
|
||||
},
|
||||
@ -211,11 +253,15 @@ func TestParseToolCalls(t *testing.T) {
|
||||
tokens := strings.Fields(tt.output)
|
||||
for _, tok := range tokens {
|
||||
s := " " + tok
|
||||
toolCalls, ok := tp.ParseToolCalls(s)
|
||||
if ok {
|
||||
success = true
|
||||
var toolCalls []api.ToolCall
|
||||
var ok bool
|
||||
if tp.state != Done {
|
||||
toolCalls, _, ok = tp.ParseToolCalls(s)
|
||||
if ok {
|
||||
success = true
|
||||
}
|
||||
got = append(got, toolCalls...)
|
||||
}
|
||||
got = append(got, toolCalls...)
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
@ -230,3 +276,5 @@ func TestParseToolCalls(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: add tests to check string sent not just tool
|
||||
|
Loading…
x
Reference in New Issue
Block a user