checkpoint
This commit is contained in:
parent
779547fcde
commit
b8b9c0c7cf
@ -1485,7 +1485,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
// var sb strings.Builder
|
|
||||||
var toolParser *ToolParser
|
var toolParser *ToolParser
|
||||||
if len(req.Tools) > 0 {
|
if len(req.Tools) > 0 {
|
||||||
toolParser = NewToolParser(m)
|
toolParser = NewToolParser(m)
|
||||||
@ -1518,6 +1517,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if len(req.Tools) > 0 && !toolParser.Done {
|
if len(req.Tools) > 0 && !toolParser.Done {
|
||||||
toolCalls, leftover := toolParser.ParseToolCalls(r.Content)
|
toolCalls, leftover := toolParser.ParseToolCalls(r.Content)
|
||||||
|
// * This can be abstracted again to a .handleState(tp.state)
|
||||||
|
// * However, we'd need a flag to indicate whether to send the response or not
|
||||||
|
// * happy to take whatever is more idiomatic
|
||||||
switch toolParser.ParserState {
|
switch toolParser.ParserState {
|
||||||
case ToolCallAccumulate:
|
case ToolCallAccumulate:
|
||||||
// tokens are accumulated in the tool parser
|
// tokens are accumulated in the tool parser
|
||||||
@ -1526,7 +1528,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
// tokens are sent back in the response
|
// tokens are sent back in the response
|
||||||
case ToolCallSendPartial:
|
case ToolCallSendPartial:
|
||||||
// tokens not needed for parsing are sent back in the response
|
// tokens not needed for parsing are sent back in the response
|
||||||
res.Message.Content = leftover
|
if len(leftover) > 0 {
|
||||||
|
res.Message.Content = leftover
|
||||||
|
}
|
||||||
|
// ! state is needed as we need to not match on the other states
|
||||||
case ToolCallFound:
|
case ToolCallFound:
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
res.Message.Content = ""
|
res.Message.Content = ""
|
||||||
@ -1534,6 +1539,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("sending response", res.Message.Content)
|
fmt.Println("sending response", res.Message.Content)
|
||||||
|
// * this is where we'd need the flag if we have a .handleState(tp.state)
|
||||||
ch <- res
|
ch <- res
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
158
server/tools.go
158
server/tools.go
@ -5,7 +5,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
|
||||||
"strings"
|
"strings"
|
||||||
gotmpl "text/template"
|
gotmpl "text/template"
|
||||||
|
|
||||||
@ -24,7 +23,9 @@ const (
|
|||||||
GreedyToolNoPrefix
|
GreedyToolNoPrefix
|
||||||
ForceTools
|
ForceTools
|
||||||
ToolSuffix
|
ToolSuffix
|
||||||
ContainsPartialPrefix
|
ContainsPrefix
|
||||||
|
PartialPrefix
|
||||||
|
NotPartialPrefix
|
||||||
Done
|
Done
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -64,9 +65,11 @@ func (s State) String() string {
|
|||||||
return "ForceTools"
|
return "ForceTools"
|
||||||
case ToolSuffix:
|
case ToolSuffix:
|
||||||
return "ToolSuffix"
|
return "ToolSuffix"
|
||||||
|
case PartialPrefix:
|
||||||
|
return "PossiblePrefix"
|
||||||
case Done:
|
case Done:
|
||||||
return "Done"
|
return "Done"
|
||||||
case ContainsPartialPrefix:
|
case ContainsPrefix:
|
||||||
return "PartialPrefix"
|
return "PartialPrefix"
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("Unknown State (%d)", s)
|
return fmt.Sprintf("Unknown State (%d)", s)
|
||||||
@ -88,6 +91,8 @@ type ToolParser struct {
|
|||||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
|
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
|
||||||
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||||
|
fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s)
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
|
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||||
"ToolCalls": {
|
"ToolCalls": {
|
||||||
@ -101,6 +106,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
fmt.Printf("failed to execute template: error=%v\n", err)
|
||||||
return nil, false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,6 +114,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
var temp any
|
var temp any
|
||||||
err := jsonv2.Unmarshal(b.Bytes(), &temp)
|
err := jsonv2.Unmarshal(b.Bytes(), &temp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fmt.Printf("failed to unmarshal template: error=%v\n", err)
|
||||||
return nil, false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,6 +132,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// TODO: err or fallback
|
// TODO: err or fallback
|
||||||
|
fmt.Printf("collect encountered unknown type: type=%T\n", obj)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,6 +150,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
templateObjects = collect(t)
|
templateObjects = collect(t)
|
||||||
}
|
}
|
||||||
if len(templateObjects) == 0 {
|
if len(templateObjects) == 0 {
|
||||||
|
fmt.Println("no template objects found")
|
||||||
return nil, false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -151,12 +160,15 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
switch v.(type) {
|
switch v.(type) {
|
||||||
case string:
|
case string:
|
||||||
name = k
|
name = k
|
||||||
|
fmt.Printf("found name field: key=%s\n", k)
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
arguments = k
|
arguments = k
|
||||||
|
fmt.Printf("found arguments field: key=%s\n", k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if name == "" || arguments == "" {
|
if name == "" || arguments == "" {
|
||||||
|
fmt.Printf("missing required fields: name_found=%v arguments_found=%v\n", name != "", arguments != "")
|
||||||
return nil, false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -165,18 +177,17 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
dec := jsontext.NewDecoder(strings.NewReader(s))
|
dec := jsontext.NewDecoder(strings.NewReader(s))
|
||||||
if got, err := dec.ReadValue(); err == nil {
|
if got, err := dec.ReadValue(); err == nil {
|
||||||
s = got.String()
|
s = got.String()
|
||||||
|
fmt.Printf("decoded JSON value: value=%s\n", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
var responseObjects any
|
var responseObjects any
|
||||||
err = jsonv2.Unmarshal([]byte(s), &responseObjects)
|
err = jsonv2.Unmarshal([]byte(s), &responseObjects)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
|
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
|
||||||
fmt.Println("Detected partial or incomplete JSON.")
|
fmt.Println("incomplete JSON detected")
|
||||||
fmt.Println("state", p.state)
|
|
||||||
return nil, true, false
|
return nil, true, false
|
||||||
} else {
|
} else {
|
||||||
fmt.Printf("Other error: %v\n", err)
|
fmt.Printf("failed to unmarshal response: error=%v\n", err)
|
||||||
fmt.Println("exiting from JSON parsing", p.state)
|
|
||||||
return nil, false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -187,14 +198,14 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
return nil, false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("collected objects", "count", len(objs))
|
fmt.Printf("collected objects: count=%d\n", len(objs))
|
||||||
|
|
||||||
var toolCalls []api.ToolCall
|
var toolCalls []api.ToolCall
|
||||||
for _, kv := range objs {
|
for _, kv := range objs {
|
||||||
n, nok := kv[name].(string)
|
n, nok := kv[name].(string)
|
||||||
a, aok := kv[arguments].(map[string]any)
|
a, aok := kv[arguments].(map[string]any)
|
||||||
if nok && aok {
|
if nok && aok {
|
||||||
slog.Debug("found valid tool call", "name", n)
|
fmt.Printf("found valid tool call: name=%s\n", n)
|
||||||
toolCalls = append(toolCalls, api.ToolCall{
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: n,
|
Name: n,
|
||||||
@ -204,84 +215,89 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("parsed tool calls", "count", len(toolCalls))
|
fmt.Printf("parsed tool calls: count=%d\n", len(toolCalls))
|
||||||
return toolCalls, false, true
|
return toolCalls, false, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ToolParser) updateOutputState(ok bool, partial bool, tcs []api.ToolCall) {
|
// TODO: clean up the boundary of internal and external state transitions
|
||||||
|
func (p *ToolParser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) {
|
||||||
|
fmt.Printf("updating output state: ok=%v partial=%v tool_calls=%d current_state=%s\n", ok, partial, len(tcs), p.state)
|
||||||
|
|
||||||
|
// state transition logic
|
||||||
switch {
|
switch {
|
||||||
case !ok && !partial && p.state == ForceTools:
|
case !ok && !partial && p.state == ForceTools:
|
||||||
fmt.Println("Case: !ok && !partial && ForceTools - staying in force tools, resetting buffer")
|
|
||||||
// force partial tool if we have a prefix
|
// force partial tool if we have a prefix
|
||||||
// no op and stay in force tools
|
// no op and stay in force tools
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
case !ok && !partial:
|
case !ok && !partial:
|
||||||
fmt.Println("Case: !ok && !partial")
|
|
||||||
fmt.Println("state", p.state)
|
|
||||||
if p.state == GreedyToolNoPrefix {
|
if p.state == GreedyToolNoPrefix {
|
||||||
fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done")
|
|
||||||
p.state = Done
|
p.state = Done
|
||||||
// p.ParserState = DoneFR
|
// ? the output parser state is the same even though internal can we not leak the external state?
|
||||||
p.ParserState = ToolCallSendTokens
|
|
||||||
p.Done = true
|
p.Done = true
|
||||||
}
|
}
|
||||||
if p.state == GreedyToolWithPrefix {
|
if p.state == GreedyToolWithPrefix {
|
||||||
fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens")
|
|
||||||
p.state = SendTokens
|
p.state = SendTokens
|
||||||
p.ParserState = ToolCallSendTokens
|
|
||||||
}
|
}
|
||||||
p.sb.Reset()
|
if p.state == PartialPrefix {
|
||||||
|
p.state = NotPartialPrefix
|
||||||
|
}
|
||||||
case !ok && partial:
|
case !ok && partial:
|
||||||
fmt.Println("Case: !ok && partial - accumulating partial content")
|
// acucumulate
|
||||||
// ! acucumulate
|
|
||||||
|
|
||||||
case len(tcs) > 0:
|
case len(tcs) > 0:
|
||||||
fmt.Println("Case: tool calls found")
|
|
||||||
// do not parse again in the greedy JSON case as soon as we have a tool call
|
// do not parse again in the greedy JSON case as soon as we have a tool call
|
||||||
if p.state == GreedyToolWithPrefix {
|
|
||||||
p.state = SendTokens
|
|
||||||
p.ParserState = ToolCallFound
|
|
||||||
p.state = Done
|
|
||||||
p.Done = true
|
|
||||||
} else if p.state == GreedyToolNoPrefix {
|
|
||||||
fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens")
|
|
||||||
p.state = Done
|
|
||||||
p.Done = true
|
|
||||||
}
|
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
}
|
}
|
||||||
p.updateExternalState(tcs)
|
p.updateExternalState(tcs)
|
||||||
|
fmt.Printf("state updated: new_state=%s parser_state=%s\n", p.state, p.ParserState)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ToolParser) updateExternalState(tcs []api.ToolCall) {
|
func (p *ToolParser) updateExternalState(tcs []api.ToolCall) {
|
||||||
if (p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix || p.state == ToolSuffix) || (p.state == ForceTools && len(tcs) == 0) {
|
fmt.Printf("updating external state: current_state=%s tool_calls=%d\n", p.state, len(tcs))
|
||||||
p.ParserState = ToolCallAccumulate
|
|
||||||
} else if p.state == ContainsPartialPrefix {
|
switch {
|
||||||
p.ParserState = ToolCallSendPartial
|
case len(tcs) > 0:
|
||||||
} else if len(tcs) > 0 {
|
// do not parse again in the greedy JSON case as soon as we have a tool call
|
||||||
|
if p.state == GreedyToolWithPrefix {
|
||||||
|
p.state = SendTokens
|
||||||
|
} else if p.state == GreedyToolNoPrefix {
|
||||||
|
p.state = Done
|
||||||
|
p.Done = true
|
||||||
|
}
|
||||||
p.ParserState = ToolCallFound
|
p.ParserState = ToolCallFound
|
||||||
} else if p.state == SendTokens {
|
case p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix ||
|
||||||
|
p.state == ToolSuffix || p.state == PartialPrefix ||
|
||||||
|
(p.state == ForceTools && len(tcs) == 0):
|
||||||
|
p.ParserState = ToolCallAccumulate
|
||||||
|
case p.state == ContainsPrefix:
|
||||||
|
p.ParserState = ToolCallSendPartial
|
||||||
|
case p.state == SendTokens || p.state == Done:
|
||||||
p.ParserState = ToolCallSendTokens
|
p.ParserState = ToolCallSendTokens
|
||||||
|
case p.state == NotPartialPrefix:
|
||||||
|
p.ParserState = ToolCallSendPartial
|
||||||
|
default:
|
||||||
|
p.ParserState = ToolCallSendTokens
|
||||||
|
p.sb.Reset()
|
||||||
|
p.state = SendTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// string, and if it has a prefix
|
// string, and if it has a prefix
|
||||||
func (p *ToolParser) checkPrefix(s string) (string, bool) {
|
func (p *ToolParser) checkPrefix(s string) (string, bool) {
|
||||||
|
fmt.Printf("checking prefix: input=%s prefix=%s\n", s, p.toolPrefix)
|
||||||
|
|
||||||
if p.toolPrefix == "" {
|
if p.toolPrefix == "" {
|
||||||
return s, true
|
return s, true
|
||||||
}
|
}
|
||||||
original := s
|
original := s
|
||||||
// s = strings.TrimSpace(s)
|
|
||||||
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
|
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
|
||||||
if hasPrefix {
|
if hasPrefix {
|
||||||
fmt.Println("has prefix", s)
|
|
||||||
p.state = ForceTools
|
p.state = ForceTools
|
||||||
// partial tool possibly
|
fmt.Printf("found exact prefix match: remaining=%s\n", s)
|
||||||
} else if strings.HasPrefix(p.toolPrefix, s) {
|
// partial tool possibly - accumulate
|
||||||
slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s)
|
} else if suffixOverlap(s, p.toolPrefix) > 0 {
|
||||||
// TODO: could possibly err maybe this should be greedy instead?
|
p.state = PartialPrefix
|
||||||
p.state = ForceTools
|
fmt.Printf("found partial prefix: remaining=%s\n", s)
|
||||||
// this would basically be a no op on rest of the input
|
|
||||||
return "", false
|
return "", false
|
||||||
// the case where "token<tool_call>" - send "token" back
|
// the case where "token<tool_call>" - send "token" back
|
||||||
// accounts for spaces in prefix or suffix to avoid breaking cache
|
// accounts for spaces in prefix or suffix to avoid breaking cache
|
||||||
@ -289,11 +305,13 @@ func (p *ToolParser) checkPrefix(s string) (string, bool) {
|
|||||||
idx := strings.Index(original, p.toolPrefix)
|
idx := strings.Index(original, p.toolPrefix)
|
||||||
if idx != -1 {
|
if idx != -1 {
|
||||||
// still keeps the prefix
|
// still keeps the prefix
|
||||||
p.state = ContainsPartialPrefix
|
p.state = ContainsPrefix
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
// todo: see if there is a simpler way for this
|
// todo: see if there is a simpler way for this
|
||||||
idx2 := strings.Index(s, p.toolPrefix)
|
idx2 := strings.Index(s, p.toolPrefix)
|
||||||
|
// buffer now only has the prefix
|
||||||
p.sb.WriteString(s[idx2:])
|
p.sb.WriteString(s[idx2:])
|
||||||
|
fmt.Printf("found prefix in middle: prefix_start=%d content_before=%s\n", idx, original[:idx])
|
||||||
return original[:idx], false
|
return original[:idx], false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -305,51 +323,71 @@ func (p *ToolParser) checkPrefix(s string) (string, bool) {
|
|||||||
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
||||||
// Returns tool calls, whether parsing is incomplete, and any errors.
|
// Returns tool calls, whether parsing is incomplete, and any errors.
|
||||||
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string) {
|
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string) {
|
||||||
fmt.Println("checking tool calls", s)
|
fmt.Printf("parsing tool calls: input=%s current_state=%s\n", s, p.state)
|
||||||
fmt.Println("external state", p.ParserState)
|
|
||||||
fmt.Println("internal state", p.state)
|
|
||||||
p.sb.WriteString(s)
|
p.sb.WriteString(s)
|
||||||
s = p.sb.String()
|
s = p.sb.String()
|
||||||
|
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
fmt.Println("sb", s)
|
|
||||||
|
|
||||||
p.updateExternalState(nil)
|
|
||||||
if len(s) == 0 {
|
if len(s) == 0 {
|
||||||
|
p.updateExternalState(nil)
|
||||||
return nil, ""
|
return nil, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
s, cont := p.checkPrefix(s)
|
s, cont := p.checkPrefix(s)
|
||||||
if !cont {
|
if !cont {
|
||||||
p.updateExternalState(nil)
|
p.updateExternalState(nil)
|
||||||
if p.state == ContainsPartialPrefix {
|
if p.state == ContainsPrefix {
|
||||||
|
fmt.Printf("returning partial prefix: remaining=%s\n", s)
|
||||||
return nil, s
|
return nil, s
|
||||||
}
|
}
|
||||||
|
// * we'd be returning here for just accumulating with possible prefix
|
||||||
|
// * ext state is accumulation
|
||||||
return nil, ""
|
return nil, ""
|
||||||
}
|
}
|
||||||
|
// * lets say the check fails here and now we're still in external state accumulation here
|
||||||
|
|
||||||
// stay in SendTokens unless we have a prefix
|
// stay in SendTokens unless we have a prefix
|
||||||
if p.state == SendTokens {
|
if p.state == SendTokens {
|
||||||
fmt.Println("SendTokens - resetting buffer")
|
|
||||||
p.updateExternalState(nil)
|
p.updateExternalState(nil)
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
return nil, ""
|
fmt.Printf("returning send tokens: remaining=%s\n", s)
|
||||||
|
return nil, s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// * we'd parse here as json to see if it's a tool call
|
||||||
tcs, partial, ok := p.parseJSONToolCalls(s)
|
tcs, partial, ok := p.parseJSONToolCalls(s)
|
||||||
p.updateOutputState(ok, partial, tcs)
|
// * it would not be a tool call here
|
||||||
fmt.Println("output state", p.ParserState, p.state)
|
p.updateStateAfterJSONParse(ok, partial, tcs)
|
||||||
if !ok {
|
if !ok {
|
||||||
fmt.Println("returning empty tool calls")
|
// * and so we should send the data here
|
||||||
|
// * we also need to move out of that internal state after sending the tokens
|
||||||
|
if p.state == NotPartialPrefix {
|
||||||
|
p.state = SendTokens
|
||||||
|
// the string would have acc until here
|
||||||
|
return nil, p.sb.String()
|
||||||
|
}
|
||||||
return nil, ""
|
return nil, ""
|
||||||
}
|
}
|
||||||
for _, tc := range tcs {
|
for _, tc := range tcs {
|
||||||
tc.Function.Index = p.toolIndex
|
tc.Function.Index = p.toolIndex
|
||||||
p.toolIndex++
|
p.toolIndex++
|
||||||
}
|
}
|
||||||
|
fmt.Printf("finished parsing tool calls: tool_calls_found=%d\n", len(tcs))
|
||||||
return tcs, ""
|
return tcs, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func suffixOverlap(s, delim string) int {
|
||||||
|
max := min(len(delim), len(s))
|
||||||
|
for i := max; i > 0; i-- {
|
||||||
|
if strings.HasSuffix(s, delim[:i]) {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func NewToolParser(model *Model) *ToolParser {
|
func NewToolParser(model *Model) *ToolParser {
|
||||||
// TODO: use new template parsing to get all tokens for the prefix
|
// TODO: use new template parsing to get all tokens for the prefix
|
||||||
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
|
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
|
||||||
@ -365,7 +403,7 @@ func NewToolParser(model *Model) *ToolParser {
|
|||||||
} else {
|
} else {
|
||||||
state = GreedyToolWithPrefix
|
state = GreedyToolWithPrefix
|
||||||
}
|
}
|
||||||
fmt.Println("setup state", state)
|
fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", templateToolPrefix, state)
|
||||||
return &ToolParser{
|
return &ToolParser{
|
||||||
tmpl: tmpl,
|
tmpl: tmpl,
|
||||||
sb: &strings.Builder{},
|
sb: &strings.Builder{},
|
||||||
|
@ -55,21 +55,21 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
expectedTokens string
|
expectedTokens string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "mistral invalid json",
|
name: "mistral malformed json with tool calls prefix",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mistral multiple tool calls - no prefix",
|
name: "mistral multiple tool calls without prefix",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
expectedToolCall: []api.ToolCall{t1, t2},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mistral tool calls with text in between - no prefix",
|
name: "mistral tool calls with text between no prefix",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
@ -77,15 +77,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mistral valid json - with prefix",
|
name: "mistral valid json with tool calls prefix",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
expectedToolCall: []api.ToolCall{t1, t2},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// In this case we'd be ignoring the text in between and just returning the tool calls
|
name: "mistral multiple tool calls with text between and prefix",
|
||||||
name: "mistral valid json with text in between - with prefix",
|
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
@ -93,14 +92,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mistral incomplete json",
|
name: "mistral incomplete json with tool calls prefix",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
|
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mistral without tool token",
|
name: "mistral invalid tool call with explanatory text no prefix",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||||
|
|
||||||
@ -109,14 +108,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mistral without tool token - tool first",
|
name: "mistral tool calls without prefix",
|
||||||
model: "mistral",
|
model: "mistral",
|
||||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
expectedToolCall: []api.ToolCall{t1, t2},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "command-r-plus with json block",
|
name: "command r plus tool calls with json block format",
|
||||||
model: "command-r-plus",
|
model: "command-r-plus",
|
||||||
output: "Action: ```json" + `
|
output: "Action: ```json" + `
|
||||||
[
|
[
|
||||||
@ -140,14 +139,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "firefunction with functools",
|
name: "firefunction tool calls with functools prefix",
|
||||||
model: "firefunction",
|
model: "firefunction",
|
||||||
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
expectedToolCall: []api.ToolCall{t1, t2},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "llama3 with tool call tags",
|
name: "llama3 groq single tool call with xml tags",
|
||||||
model: "llama3-groq-tool-use",
|
model: "llama3-groq-tool-use",
|
||||||
output: `<tool_call>
|
output: `<tool_call>
|
||||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||||
@ -156,99 +155,126 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "xlam with tool_calls wrapper",
|
name: "xlam tool calls with wrapper object",
|
||||||
model: "xlam",
|
model: "xlam",
|
||||||
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
|
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
|
||||||
expectedToolCall: []api.ToolCall{t1, t2},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen2.5 with single tool call",
|
name: "qwen2.5-coder single tool call with prefix",
|
||||||
model: "qwen2.5-coder",
|
model: "qwen2.5-coder",
|
||||||
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{t1},
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen with no tool prefix",
|
name: "qwen2.5-coder multiple tool calls with and without prefix",
|
||||||
|
model: "qwen2.5-coder",
|
||||||
|
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call> <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5-coder multiple tool calls without prefix",
|
||||||
model: "qwen2.5-coder",
|
model: "qwen2.5-coder",
|
||||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
expectedToolCall: []api.ToolCall{t1, t2},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen with no tool calls",
|
name: "qwen2.5-coder plain text response no tool calls",
|
||||||
model: "qwen2.5-coder",
|
model: "qwen2.5-coder",
|
||||||
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen with no tool prefix",
|
name: "qwen2.5-coder tool calls with trailing text",
|
||||||
model: "qwen2.5-coder",
|
model: "qwen2.5-coder",
|
||||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
|
||||||
expectedToolCall: []api.ToolCall{t1, t2},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: "some tokens after call",
|
expectedTokens: "some tokens after call",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen with prefix",
|
name: "qwen2.5 tool calls with prefix and trailing text",
|
||||||
model: "qwen2.5-coder",
|
model: "qwen2.5-coder",
|
||||||
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
|
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
|
||||||
expectedToolCall: []api.ToolCall{t1, t2},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// tests the leftover logic as well
|
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
|
||||||
name: "qwen3 with single tool call and thinking",
|
|
||||||
model: "qwen3",
|
model: "qwen3",
|
||||||
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{t1},
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen3 with single tool call and thinking spaces",
|
name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)",
|
||||||
model: "qwen3",
|
model: "qwen3",
|
||||||
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{t1},
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen3 testing",
|
name: "qwen3 empty think prefix without tool prefix and invalid tool call",
|
||||||
model: "qwen3",
|
model: "qwen3",
|
||||||
output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen3 testing 2",
|
name: "qwen3 empty think prefix with tool prefix and valid tool call",
|
||||||
model: "qwen3",
|
model: "qwen3",
|
||||||
output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{t1},
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
expectedTokens: `<think></think>`,
|
expectedTokens: `<think></think>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "llama3.2 with tool call - no prefix",
|
name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3 invalid tool call with malformed tool prefix",
|
||||||
|
model: "qwen3",
|
||||||
|
output: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "llama3.2 valid tool call without prefix",
|
||||||
model: "llama3.2",
|
model: "llama3.2",
|
||||||
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||||
expectedToolCall: []api.ToolCall{t1},
|
expectedToolCall: []api.ToolCall{t1},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "llama3.2 with incomplete tool call - no prefix",
|
name: "llama3.2 incomplete tool call without prefix",
|
||||||
model: "llama3.2",
|
model: "llama3.2",
|
||||||
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
|
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "llama3.2 with tool call - in middle",
|
name: "llama3.2 tool call with leading text",
|
||||||
model: "llama3.2",
|
model: "llama3.2",
|
||||||
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "llama3.2 - fake tool prefix",
|
name: "llama3.2 tool call with invalid tool prefix (no prefix in template)",
|
||||||
model: "llama3.2",
|
model: "llama3.2",
|
||||||
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
@ -288,7 +314,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
m := &Model{Template: tmpl}
|
m := &Model{Template: tmpl}
|
||||||
tp := NewToolParser(m)
|
tp := NewToolParser(m)
|
||||||
got := []api.ToolCall{}
|
got := []api.ToolCall{}
|
||||||
var actualTokens strings.Builder
|
var gotTokens strings.Builder
|
||||||
|
|
||||||
tokens := strings.Fields(tt.output)
|
tokens := strings.Fields(tt.output)
|
||||||
for _, tok := range tokens {
|
for _, tok := range tokens {
|
||||||
@ -302,17 +328,18 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
got = append(got, toolCalls...)
|
got = append(got, toolCalls...)
|
||||||
add = false
|
add = false
|
||||||
case ToolCallSendTokens:
|
case ToolCallSendTokens:
|
||||||
actualTokens.WriteString(s)
|
gotTokens.WriteString(s)
|
||||||
add = false
|
add = false
|
||||||
case ToolCallAccumulate:
|
case ToolCallAccumulate:
|
||||||
add = false
|
add = false
|
||||||
case ToolCallSendPartial:
|
case ToolCallSendPartial:
|
||||||
actualTokens.WriteString(" " + leftover)
|
t.Log("send partial", "leftover", leftover)
|
||||||
|
gotTokens.WriteString(" " + leftover)
|
||||||
add = false
|
add = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if add {
|
if add {
|
||||||
actualTokens.WriteString(s)
|
gotTokens.WriteString(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -322,7 +349,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compare tokens if we expect any
|
// Compare tokens if we expect any
|
||||||
stripped := strings.TrimSpace(actualTokens.String())
|
stripped := strings.TrimSpace(gotTokens.String())
|
||||||
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
|
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
|
||||||
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
|
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
|
||||||
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
|
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user