wip
This commit is contained in:
parent
516a540df7
commit
b5a982ecb0
@ -1517,13 +1517,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(req.Tools) > 0 && !tp.done {
|
if len(req.Tools) > 0 && tp.state != Done {
|
||||||
fmt.Println("checking tool calls")
|
fmt.Println("checking tool calls")
|
||||||
toolCalls, ok := tp.ParseToolCalls(r.Content)
|
toolCalls, leftover, ok := tp.ParseToolCalls(r.Content)
|
||||||
if tp.state == PartialTool {
|
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
|
||||||
fmt.Println("partial tool, returning")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if tp.state == ContainsPartialPrefix {
|
||||||
|
fmt.Println("sending tokens", leftover)
|
||||||
|
res.Message.Content = leftover
|
||||||
|
}
|
||||||
if ok && len(toolCalls) > 0 {
|
if ok && len(toolCalls) > 0 {
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
for i := range toolCalls {
|
for i := range toolCalls {
|
||||||
@ -1535,6 +1538,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("sending response", res.Message.Content)
|
||||||
ch <- res
|
ch <- res
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
169
server/tools.go
169
server/tools.go
@ -17,17 +17,42 @@ import (
|
|||||||
type State int
|
type State int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoTool State = iota
|
SendTokens State = iota
|
||||||
PartialTool
|
GreedyToolWithPrefix
|
||||||
ToolCall
|
GreedyToolNoPrefix
|
||||||
|
// ToolCall
|
||||||
|
ForceTools
|
||||||
|
ToolSuffix
|
||||||
|
ContainsPartialPrefix
|
||||||
|
Done
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s State) String() string {
|
||||||
|
switch s {
|
||||||
|
case SendTokens:
|
||||||
|
return "SendTokens"
|
||||||
|
case GreedyToolWithPrefix:
|
||||||
|
return "GreedyToolWithPrefix"
|
||||||
|
case GreedyToolNoPrefix:
|
||||||
|
return "GreedyToolNoPrefix"
|
||||||
|
case ForceTools:
|
||||||
|
return "ForceTools"
|
||||||
|
case ToolSuffix:
|
||||||
|
return "ToolSuffix"
|
||||||
|
case Done:
|
||||||
|
return "Done"
|
||||||
|
case ContainsPartialPrefix:
|
||||||
|
return "PartialPrefix"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Unknown State (%d)", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type ToolParser struct {
|
type ToolParser struct {
|
||||||
tmpl *gotmpl.Template
|
tmpl *gotmpl.Template
|
||||||
state State
|
state State
|
||||||
sb *strings.Builder
|
sb *strings.Builder
|
||||||
toolPrefix string
|
toolPrefix string
|
||||||
done bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
@ -49,8 +74,6 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
return nil, false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// slog.Debug("template", "template", b.String())
|
|
||||||
|
|
||||||
// ! this can be either a map or an array
|
// ! this can be either a map or an array
|
||||||
var temp any
|
var temp any
|
||||||
err := jsonv2.Unmarshal(b.Bytes(), &temp)
|
err := jsonv2.Unmarshal(b.Bytes(), &temp)
|
||||||
@ -88,7 +111,6 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
if len(templateObjects) == 0 {
|
if len(templateObjects) == 0 {
|
||||||
return nil, false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
// fmt.Println("template objects", templateObjects)
|
|
||||||
|
|
||||||
// find the keys that correspond to the name and arguments fields
|
// find the keys that correspond to the name and arguments fields
|
||||||
var name, arguments string
|
var name, arguments string
|
||||||
@ -142,81 +164,134 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("parsed tool calls", "count", len(toolCalls))
|
slog.Debug("parsed tool calls", "count", len(toolCalls))
|
||||||
return toolCalls, len(toolCalls) > 0, true
|
return toolCalls, false, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ToolParser) updateState(ok bool, partial bool, tcs []api.ToolCall) {
|
||||||
|
switch {
|
||||||
|
case !ok && !partial && p.state == ForceTools:
|
||||||
|
fmt.Println("Case: !ok && !partial && ForceTools - staying in force tools, resetting buffer")
|
||||||
|
// force partial tool if we have a prefix
|
||||||
|
// no op and stay in force tools
|
||||||
|
p.sb.Reset()
|
||||||
|
case !ok && !partial:
|
||||||
|
fmt.Println("Case: !ok && !partial")
|
||||||
|
fmt.Println("state", p.state)
|
||||||
|
if p.state == GreedyToolNoPrefix {
|
||||||
|
fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done")
|
||||||
|
p.state = Done
|
||||||
|
}
|
||||||
|
if p.state == GreedyToolWithPrefix {
|
||||||
|
fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens")
|
||||||
|
p.state = SendTokens
|
||||||
|
}
|
||||||
|
p.sb.Reset()
|
||||||
|
case !ok && partial:
|
||||||
|
fmt.Println("Case: !ok && partial - accumulating partial content")
|
||||||
|
|
||||||
|
// ! acucumulate
|
||||||
|
|
||||||
|
case len(tcs) > 0:
|
||||||
|
fmt.Println("Case: tool calls found")
|
||||||
|
// do not parse again in the greedy JSON case as soon as we have a tool call
|
||||||
|
if p.state == GreedyToolWithPrefix {
|
||||||
|
p.state = SendTokens
|
||||||
|
} else if p.state == GreedyToolNoPrefix {
|
||||||
|
fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens")
|
||||||
|
p.state = Done
|
||||||
|
}
|
||||||
|
p.sb.Reset()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
||||||
// Returns tool calls, whether parsing is incomplete, and any errors.
|
// Returns tool calls, whether parsing is incomplete, and any errors.
|
||||||
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, bool) {
|
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) {
|
||||||
p.sb.WriteString(s)
|
p.sb.WriteString(s)
|
||||||
s = p.sb.String()
|
s = p.sb.String()
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
slog.Debug("parse tool calls", "content", s)
|
slog.Debug("parse tool calls", "content", s)
|
||||||
|
|
||||||
if len(s) == 0 {
|
if len(s) == 0 {
|
||||||
return nil, false
|
return nil, "", false
|
||||||
}
|
}
|
||||||
hasPrefix := false
|
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
|
||||||
|
fmt.Println("hasPrefix", hasPrefix)
|
||||||
|
var tcs []api.ToolCall
|
||||||
|
var partial bool
|
||||||
|
var ok bool
|
||||||
|
|
||||||
if p.toolPrefix != "" {
|
if p.toolPrefix != "" {
|
||||||
if strings.HasPrefix(s, p.toolPrefix) {
|
if hasPrefix {
|
||||||
s = strings.TrimSpace(s[len(p.toolPrefix):])
|
p.state = ForceTools
|
||||||
slog.Debug("tool prefix", "prefix", p.toolPrefix, "content", s)
|
slog.Debug("tool prefix in prefix", "prefix", p.toolPrefix, "content", s)
|
||||||
p.state = PartialTool
|
// partial tool possibly
|
||||||
hasPrefix = true
|
} else if strings.HasPrefix(p.toolPrefix, s) {
|
||||||
|
slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s)
|
||||||
|
// TODO: could possibly err maybe this should be greedy instead?
|
||||||
|
p.state = ForceTools
|
||||||
|
return nil, "", false
|
||||||
|
} else if strings.Contains(s, p.toolPrefix) {
|
||||||
|
idx := strings.Index(s, p.toolPrefix)
|
||||||
|
if idx != -1 {
|
||||||
|
// still keeps the prefix
|
||||||
|
p.state = ContainsPartialPrefix
|
||||||
|
p.sb.Reset()
|
||||||
|
p.sb.WriteString(s[idx:])
|
||||||
|
return nil, s[:idx], false
|
||||||
|
}
|
||||||
|
}
|
||||||
// Special token end case
|
// Special token end case
|
||||||
} else if strings.HasSuffix(s, p.toolPrefix[2:]) {
|
// if s, ok := strings.CutSuffix(s, p.toolPrefix[2:]); ok {
|
||||||
p.state = PartialTool
|
if strings.HasSuffix(s, p.toolPrefix[2:]) {
|
||||||
|
// can be with string or just the token
|
||||||
|
if hasPrefix {
|
||||||
|
s = strings.TrimSpace(s[:len(s)-(len(p.toolPrefix)+1)])
|
||||||
|
} else {
|
||||||
|
p.state = ToolSuffix
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
|
return nil, "", false
|
||||||
|
}
|
||||||
slog.Debug("setting to no tool", "content", s)
|
slog.Debug("setting to no tool", "content", s)
|
||||||
return nil, false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tcs, partial, ok := p.parseJSONToolCalls(s)
|
fmt.Println("s before parsing", s)
|
||||||
|
if p.state == SendTokens {
|
||||||
// TODO: figure out how to return the remaining string if not partial anymore
|
fmt.Println("returning nil cause of send tokens")
|
||||||
// update state
|
return nil, "", false
|
||||||
switch {
|
|
||||||
case !ok && !partial && hasPrefix:
|
|
||||||
p.state = PartialTool
|
|
||||||
case !ok && !partial:
|
|
||||||
p.state = NoTool
|
|
||||||
case !ok && partial:
|
|
||||||
p.state = PartialTool
|
|
||||||
case len(tcs) > 0:
|
|
||||||
p.state = ToolCall
|
|
||||||
}
|
}
|
||||||
|
tcs, partial, ok = p.parseJSONToolCalls(s)
|
||||||
if p.state == NoTool || p.state == ToolCall {
|
|
||||||
slog.Debug("resetting string builder", "state", p.state)
|
|
||||||
p.sb.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("returning tool calls", "tool calls", tcs)
|
slog.Debug("returning tool calls", "tool calls", tcs)
|
||||||
fmt.Println("end state", p.state)
|
fmt.Println("end state", p.state)
|
||||||
if p.toolPrefix == "" {
|
|
||||||
p.done = true
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("len tcs", len(tcs))
|
fmt.Println("len tcs", len(tcs))
|
||||||
return tcs, true
|
p.updateState(ok, partial, tcs)
|
||||||
|
if !ok {
|
||||||
|
return nil, "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return tcs, "", true
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewToolParser(model *Model) *ToolParser {
|
func NewToolParser(model *Model) *ToolParser {
|
||||||
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
|
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
|
||||||
|
templateToolPrefix = strings.TrimSpace(templateToolPrefix)
|
||||||
slog.Debug("tool prefix", "prefix", templateToolPrefix)
|
slog.Debug("tool prefix", "prefix", templateToolPrefix)
|
||||||
tmpl, ok := ToolTemplate(model)
|
tmpl, ok := ToolTemplate(model)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var state State
|
||||||
|
if templateToolPrefix == "" {
|
||||||
|
state = GreedyToolNoPrefix
|
||||||
|
} else {
|
||||||
|
state = GreedyToolWithPrefix
|
||||||
|
}
|
||||||
return &ToolParser{
|
return &ToolParser{
|
||||||
tmpl: tmpl,
|
tmpl: tmpl,
|
||||||
sb: &strings.Builder{},
|
sb: &strings.Builder{},
|
||||||
toolPrefix: templateToolPrefix,
|
toolPrefix: templateToolPrefix,
|
||||||
done: false,
|
state: state,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
model string
|
model string
|
||||||
output string
|
output string
|
||||||
token string
|
prefix string
|
||||||
expected []api.ToolCall
|
expected []api.ToolCall
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
@ -59,23 +59,50 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
name: "mistral invalid json",
|
name: "mistral invalid json",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
|
||||||
token: "[TOOL_CALLS]",
|
prefix: "[TOOL_CALLS]",
|
||||||
expected: []api.ToolCall{},
|
expected: []api.ToolCall{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mistral valid json",
|
name: "mistral multiple tool calls - no prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
prefix: "[TOOL_CALLS]",
|
||||||
|
expected: []api.ToolCall{t1, t2},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral tool calls with text in between - no prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
|
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
prefix: "[TOOL_CALLS]",
|
||||||
|
expected: []api.ToolCall{t1, t2},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral valid json - with prefix",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
token: "[TOOL_CALLS]",
|
prefix: "[TOOL_CALLS]",
|
||||||
expected: []api.ToolCall{t1, t2},
|
expected: []api.ToolCall{t1, t2},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
// In this case we'd be ignoring the text in between and just returning the tool calls
|
||||||
|
name: "mistral valid json with text in between - with prefix",
|
||||||
|
model: "mistral",
|
||||||
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
|
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
prefix: "[TOOL_CALLS]",
|
||||||
|
expected: []api.ToolCall{t1, t2, t1, t2},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "mistral incomplete json",
|
name: "mistral incomplete json",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
|
||||||
token: "[TOOL_CALLS]",
|
prefix: "[TOOL_CALLS]",
|
||||||
expected: []api.ToolCall{},
|
expected: []api.ToolCall{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@ -85,7 +112,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||||
|
|
||||||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
token: "[TOOL_CALLS]",
|
prefix: "[TOOL_CALLS]",
|
||||||
expected: []api.ToolCall{},
|
expected: []api.ToolCall{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@ -93,7 +120,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
name: "mistral without tool token - tool first",
|
name: "mistral without tool token - tool first",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
token: "[TOOL_CALLS]",
|
prefix: "[TOOL_CALLS]",
|
||||||
expected: []api.ToolCall{t1, t2},
|
expected: []api.ToolCall{t1, t2},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
@ -118,7 +145,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
` + "```",
|
` + "```",
|
||||||
token: "Action:",
|
prefix: "Action: ```json",
|
||||||
expected: []api.ToolCall{t1, t2},
|
expected: []api.ToolCall{t1, t2},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
@ -126,7 +153,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
name: "firefunction with functools",
|
name: "firefunction with functools",
|
||||||
model: "firefunction",
|
model: "firefunction",
|
||||||
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
token: "functools",
|
prefix: "functools",
|
||||||
expected: []api.ToolCall{t1, t2},
|
expected: []api.ToolCall{t1, t2},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
@ -136,7 +163,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
output: `<tool_call>
|
output: `<tool_call>
|
||||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||||
</tool_call>`,
|
</tool_call>`,
|
||||||
token: "<tool_call>",
|
prefix: "<tool_call>",
|
||||||
expected: []api.ToolCall{t1},
|
expected: []api.ToolCall{t1},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
@ -144,16 +171,15 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
name: "xlam with tool_calls wrapper",
|
name: "xlam with tool_calls wrapper",
|
||||||
model: "xlam",
|
model: "xlam",
|
||||||
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
|
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
|
||||||
token: "",
|
prefix: "",
|
||||||
expected: []api.ToolCall{t1, t2},
|
expected: []api.ToolCall{t1, t2},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// TODO: fix the spacing issue
|
|
||||||
name: "qwen with single tool call",
|
name: "qwen with single tool call",
|
||||||
model: "qwen2.5-coder",
|
model: "qwen2.5-coder",
|
||||||
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||||
token: "<tool_call>",
|
prefix: "<tool_call>",
|
||||||
expected: []api.ToolCall{t1},
|
expected: []api.ToolCall{t1},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
@ -161,15 +187,31 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
name: "qwen with invalid tool token",
|
name: "qwen with invalid tool token",
|
||||||
model: "qwen2.5-coder",
|
model: "qwen2.5-coder",
|
||||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
token: "[TOOL_CALLS]",
|
prefix: "[TOOL_CALLS]",
|
||||||
expected: []api.ToolCall{t1, t2},
|
expected: []api.ToolCall{t1, t2},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 with single tool call and thinking",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||||
|
prefix: "<tool_call>",
|
||||||
|
expected: []api.ToolCall{t1},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 with single tool call and thinking spaces",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
prefix: "<tool_call>",
|
||||||
|
expected: []api.ToolCall{t1},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "qwen with no tool calls",
|
name: "qwen with no tool calls",
|
||||||
model: "qwen2.5-coder",
|
model: "qwen2.5-coder",
|
||||||
output: " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
output: " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||||
token: "",
|
prefix: "",
|
||||||
expected: []api.ToolCall{},
|
expected: []api.ToolCall{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@ -211,12 +253,16 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
tokens := strings.Fields(tt.output)
|
tokens := strings.Fields(tt.output)
|
||||||
for _, tok := range tokens {
|
for _, tok := range tokens {
|
||||||
s := " " + tok
|
s := " " + tok
|
||||||
toolCalls, ok := tp.ParseToolCalls(s)
|
var toolCalls []api.ToolCall
|
||||||
|
var ok bool
|
||||||
|
if tp.state != Done {
|
||||||
|
toolCalls, _, ok = tp.ParseToolCalls(s)
|
||||||
if ok {
|
if ok {
|
||||||
success = true
|
success = true
|
||||||
}
|
}
|
||||||
got = append(got, toolCalls...)
|
got = append(got, toolCalls...)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !tt.wantErr {
|
if !tt.wantErr {
|
||||||
if diff := cmp.Diff(got, tt.expected); diff != "" {
|
if diff := cmp.Diff(got, tt.expected); diff != "" {
|
||||||
@ -230,3 +276,5 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: add tests to check string sent not just tool
|
||||||
|
Loading…
x
Reference in New Issue
Block a user