checkpoint - cleanup still left, functionality setup

This commit is contained in:
ParthSareen 2025-05-08 18:48:44 -07:00
parent 6cb7494061
commit 779547fcde
3 changed files with 177 additions and 183 deletions

View File

@ -1486,10 +1486,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
go func() {
defer close(ch)
// var sb strings.Builder
var toolCallIndex int = 0
var tp *ToolParser
var toolParser *ToolParser
if len(req.Tools) > 0 {
tp = NewToolParser(m)
toolParser = NewToolParser(m)
}
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
@ -1517,37 +1516,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if len(req.Tools) > 0 && tp.state != Done {
fmt.Println("checking tool calls")
/*
This should give us a few return things we shouldnt have to build things up.
1. tool calls if any
2. leftover tokens if any - this happens in the partial case where we have a prefix inside a string
3. if we need to skip this loop and not send anything back
between these three things, we should just be switching on either the state or something to capture this
potentially consider a difference between internal and external state
*/
toolCalls, leftover, ok := tp.ParseToolCalls(r.Content)
// todo: this should just be one check/state coming back from the parse tool calls
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
if len(req.Tools) > 0 && !toolParser.Done {
toolCalls, leftover := toolParser.ParseToolCalls(r.Content)
switch toolParser.ParserState {
case ToolCallAccumulate:
// tokens are accumulated in the tool parser
return
}
// todo: our second state also should just be a var
if tp.state == ContainsPartialPrefix {
fmt.Println("sending tokens", leftover)
case ToolCallSendTokens:
// tokens are sent back in the response
case ToolCallSendPartial:
// tokens not needed for parsing are sent back in the response
res.Message.Content = leftover
}
// TODO: this can be done inside the parse tool calls
if ok && len(toolCalls) > 0 {
case ToolCallFound:
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
// Remove content when tool call is present
res.Message.Content = ""
}
}

View File

@ -10,12 +10,14 @@ import (
gotmpl "text/template"
jsonv2 "github.com/go-json-experiment/json"
jsontext "github.com/go-json-experiment/json/jsontext"
"github.com/ollama/ollama/api"
)
type State int
// TODO: potentially coalesce states
const (
SendTokens State = iota
GreedyToolWithPrefix
@ -26,6 +28,30 @@ const (
Done
)
type ExternalState int
const (
ToolCallFound ExternalState = iota
ToolCallSendPartial
ToolCallAccumulate
ToolCallSendTokens
)
func (s ExternalState) String() string {
switch s {
case ToolCallFound:
return "ToolCallFound"
case ToolCallSendPartial:
return "ToolCallSendPartial"
case ToolCallAccumulate:
return "ToolCallAccumulate"
case ToolCallSendTokens:
return "ToolCallSendTokens"
default:
return fmt.Sprintf("Unknown ExternalState (%d)", s)
}
}
func (s State) String() string {
switch s {
case SendTokens:
@ -47,13 +73,18 @@ func (s State) String() string {
}
}
// TODO: simplify if possible
type ToolParser struct {
tmpl *gotmpl.Template
state State
sb *strings.Builder
toolPrefix string
tmpl *gotmpl.Template
state State
sb *strings.Builder
toolPrefix string
toolIndex int
ParserState ExternalState
Done bool
}
// ? move to a separate file
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
@ -73,7 +104,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
return nil, false, false
}
// ! this can be either a map or an array
// this can be either a map or an array
var temp any
err := jsonv2.Unmarshal(b.Bytes(), &temp)
if err != nil {
@ -128,6 +159,14 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
if name == "" || arguments == "" {
return nil, false, false
}
// TODO: there is probably some underlying repeat work here to avoid
// This incrementally decodes the JSON string and returns the first parsedobject
dec := jsontext.NewDecoder(strings.NewReader(s))
if got, err := dec.ReadValue(); err == nil {
s = got.String()
}
var responseObjects any
err = jsonv2.Unmarshal([]byte(s), &responseObjects)
if err != nil {
@ -137,7 +176,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
return nil, true, false
} else {
fmt.Printf("Other error: %v\n", err)
fmt.Println("exiting", p.state)
fmt.Println("exiting from JSON parsing", p.state)
return nil, false, false
}
}
@ -182,10 +221,14 @@ func (p *ToolParser) updateOutputState(ok bool, partial bool, tcs []api.ToolCall
if p.state == GreedyToolNoPrefix {
fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done")
p.state = Done
// p.ParserState = DoneFR
p.ParserState = ToolCallSendTokens
p.Done = true
}
if p.state == GreedyToolWithPrefix {
fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens")
p.state = SendTokens
p.ParserState = ToolCallSendTokens
}
p.sb.Reset()
case !ok && partial:
@ -197,131 +240,118 @@ func (p *ToolParser) updateOutputState(ok bool, partial bool, tcs []api.ToolCall
// do not parse again in the greedy JSON case as soon as we have a tool call
if p.state == GreedyToolWithPrefix {
p.state = SendTokens
p.ParserState = ToolCallFound
p.state = Done
p.Done = true
} else if p.state == GreedyToolNoPrefix {
fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens")
p.state = Done
p.Done = true
}
p.sb.Reset()
}
p.updateExternalState(tcs)
}
func (p *ToolParser) updateInputState(s string, hasPrefix bool) (string, bool) {
func (p *ToolParser) updateExternalState(tcs []api.ToolCall) {
if (p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix || p.state == ToolSuffix) || (p.state == ForceTools && len(tcs) == 0) {
p.ParserState = ToolCallAccumulate
} else if p.state == ContainsPartialPrefix {
p.ParserState = ToolCallSendPartial
} else if len(tcs) > 0 {
p.ParserState = ToolCallFound
} else if p.state == SendTokens {
p.ParserState = ToolCallSendTokens
}
}
// string, and if it has a prefix
func (p *ToolParser) checkPrefix(s string) (string, bool) {
if p.toolPrefix == "" {
return s, true
}
original := s
// s = strings.TrimSpace(s)
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
if hasPrefix {
fmt.Println("has prefix", s)
p.state = ForceTools
// partial tool possibly
} else if strings.HasPrefix(p.toolPrefix, s) {
slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s)
// TODO: could possibly err maybe this should be greedy instead?
p.state = ForceTools
// this would basically be a no op on rest of the input
return "", false
} else if strings.Contains(s, p.toolPrefix) {
idx := strings.Index(s, p.toolPrefix)
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
p.sb.Reset()
p.sb.WriteString(s[idx:])
return s[:idx], false
}
}
// Special token end case
if strings.HasSuffix(s, p.toolPrefix[2:]) {
// can be with string or just the token
if hasPrefix {
s = strings.TrimSpace(s[:len(s)-(len(p.toolPrefix)+1)])
} else {
p.state = ToolSuffix
p.sb.Reset()
return "", false
}
slog.Debug("setting to no tool", "content", s)
}
return s, true
}
func (p *ToolParser) sendTokens(original string, hasPrefix bool) (string, bool) {
if p.state == SendTokens {
return "", false
}
if p.state == ContainsPartialPrefix {
// the case where "token<tool_call>" - send "token" back
// accounts for spaces in prefix or suffix to avoid breaking cache
} else if strings.Contains(original, p.toolPrefix) {
idx := strings.Index(original, p.toolPrefix)
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
return original[:idx], false
} else {
fmt.Println("some weird state")
}
} else if strings.HasSuffix(original, p.toolPrefix[2:]) {
// can be with string or just the token
if hasPrefix {
original = strings.TrimSpace(original[:len(original)-(len(p.toolPrefix)+1)])
return original, false
} else {
p.state = ToolSuffix
p.sb.Reset()
// todo: see if there is a simpler way for this
idx2 := strings.Index(s, p.toolPrefix)
p.sb.WriteString(s[idx2:])
return original[:idx], false
}
}
return "", true
return s, true
}
// TODO: simplify the flow of this function
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
// Returns tool calls, whether parsing is incomplete, and any errors.
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) {
original := s
// append input
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string) {
fmt.Println("checking tool calls", s)
fmt.Println("external state", p.ParserState)
fmt.Println("internal state", p.state)
p.sb.WriteString(s)
s = p.sb.String()
s = strings.TrimSpace(s)
fmt.Println("sb", s)
p.updateExternalState(nil)
if len(s) == 0 {
return nil, "", false
return nil, ""
}
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
s, ok := p.updateInputState(s, hasPrefix)
if !ok {
s, cont := p.checkPrefix(s)
if !cont {
p.updateExternalState(nil)
if p.state == ContainsPartialPrefix {
idx := strings.Index(original, p.toolPrefix)
if idx != -1 {
// still keeps the prefix
p.state = ContainsPartialPrefix
// p.sb.Reset()
// p.sb.WriteString(original[idx:])
return nil, original[:idx], false
} else {
fmt.Println("some weird state")
}
// }
// s, ok = p.sendTokens(original, hasPrefix)
// if ok {
// return nil, s, true
return nil, s
}
return nil, "", false
return nil, ""
}
// stay in SendTokens unless we have a prefix
if p.state == SendTokens {
return nil, "", false
fmt.Println("SendTokens - resetting buffer")
p.updateExternalState(nil)
p.sb.Reset()
return nil, ""
}
var tcs []api.ToolCall
var partial bool
tcs, partial, ok = p.parseJSONToolCalls(s)
tcs, partial, ok := p.parseJSONToolCalls(s)
p.updateOutputState(ok, partial, tcs)
fmt.Println("output state", p.ParserState, p.state)
if !ok {
return nil, "", false
fmt.Println("returning empty tool calls")
return nil, ""
}
return tcs, "", true
for _, tc := range tcs {
tc.Function.Index = p.toolIndex
p.toolIndex++
}
return tcs, ""
}
func NewToolParser(model *Model) *ToolParser {
// TODO: use new template parsing to get all tokens for the prefix
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
templateToolPrefix = strings.TrimSpace(templateToolPrefix)
tmpl, ok := ToolTemplate(model)
@ -337,9 +367,10 @@ func NewToolParser(model *Model) *ToolParser {
}
fmt.Println("setup state", state)
return &ToolParser{
tmpl: tmpl,
sb: &strings.Builder{},
toolPrefix: templateToolPrefix,
state: state,
tmpl: tmpl,
sb: &strings.Builder{},
toolPrefix: templateToolPrefix,
state: state,
ParserState: ToolCallAccumulate,
}
}

View File

@ -53,7 +53,6 @@ func TestParseToolCalls(t *testing.T) {
output string
expectedToolCall []api.ToolCall
expectedTokens string
wantErr bool
}{
{
name: "mistral invalid json",
@ -61,7 +60,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
wantErr: true,
},
{
name: "mistral multiple tool calls - no prefix",
@ -69,7 +67,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "mistral tool calls with text in between - no prefix",
@ -78,7 +75,6 @@ func TestParseToolCalls(t *testing.T) {
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: false,
},
{
name: "mistral valid json - with prefix",
@ -86,7 +82,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
// In this case we'd be ignoring the text in between and just returning the tool calls
@ -96,7 +91,6 @@ func TestParseToolCalls(t *testing.T) {
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "mistral incomplete json",
@ -104,7 +98,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
wantErr: true,
},
{
name: "mistral without tool token",
@ -114,7 +107,6 @@ func TestParseToolCalls(t *testing.T) {
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
wantErr: true,
},
{
name: "mistral without tool token - tool first",
@ -122,7 +114,6 @@ func TestParseToolCalls(t *testing.T) {
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "command-r-plus with json block",
@ -147,7 +138,6 @@ func TestParseToolCalls(t *testing.T) {
` + "```",
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "firefunction with functools",
@ -155,7 +145,6 @@ func TestParseToolCalls(t *testing.T) {
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "llama3 with tool call tags",
@ -165,7 +154,6 @@ func TestParseToolCalls(t *testing.T) {
</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
wantErr: false,
},
{
name: "xlam with tool_calls wrapper",
@ -173,7 +161,6 @@ func TestParseToolCalls(t *testing.T) {
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "qwen2.5 with single tool call",
@ -181,15 +168,34 @@ func TestParseToolCalls(t *testing.T) {
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
wantErr: false,
},
{
name: "qwen with invalid tool token",
name: "qwen with no tool prefix",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
wantErr: false,
},
{
name: "qwen with no tool calls",
model: "qwen2.5-coder",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
},
{
name: "qwen with no tool prefix",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "some tokens after call",
},
{
name: "qwen with prefix",
model: "qwen2.5-coder",
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
// tests the leftover logic as well
@ -198,7 +204,6 @@ func TestParseToolCalls(t *testing.T) {
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
},
{
name: "qwen3 with single tool call and thinking spaces",
@ -206,31 +211,20 @@ func TestParseToolCalls(t *testing.T) {
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
wantErr: false,
},
// {
// name: "qwen3 testing",
// model: "qwen3",
// output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{},
// expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// wantErr: true,
// },
// {
// name: "qwen3 testing 2",
// model: "qwen3",
// output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
// expectedToolCall: []api.ToolCall{t1},
// expectedTokens: `<think></think>`,
// wantErr: true,
// },
{
name: "qwen with no tool calls",
model: "qwen2.5-coder",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
name: "qwen3 testing",
model: "qwen3",
output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
wantErr: true,
expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "qwen3 testing 2",
model: "qwen3",
output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: `<think></think>`,
},
{
name: "llama3.2 with tool call - no prefix",
@ -238,7 +232,6 @@ func TestParseToolCalls(t *testing.T) {
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
wantErr: false,
},
{
name: "llama3.2 with incomplete tool call - no prefix",
@ -246,7 +239,6 @@ func TestParseToolCalls(t *testing.T) {
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
wantErr: true,
},
{
name: "llama3.2 with tool call - in middle",
@ -254,7 +246,6 @@ func TestParseToolCalls(t *testing.T) {
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
},
{
name: "llama3.2 - fake tool prefix",
@ -262,7 +253,6 @@ func TestParseToolCalls(t *testing.T) {
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
wantErr: true,
},
}
@ -298,7 +288,6 @@ func TestParseToolCalls(t *testing.T) {
m := &Model{Template: tmpl}
tp := NewToolParser(m)
got := []api.ToolCall{}
success := false
var actualTokens strings.Builder
tokens := strings.Fields(tt.output)
@ -306,40 +295,33 @@ func TestParseToolCalls(t *testing.T) {
add := true
s := " " + tok
// TODO(parthsareen): This logic is brittle as it mocks the logic in route, however can
if tp.state != Done {
toolCalls, leftover, ok := tp.ParseToolCalls(s)
if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) {
continue
}
if tp.state == ContainsPartialPrefix {
// actualTokens.Reset()
actualTokens.WriteString(leftover)
t.Log("leftover", leftover)
add = false
// continue
}
if ok && len(toolCalls) > 0 {
success = true
if !tp.Done {
toolCalls, leftover := tp.ParseToolCalls(s)
switch tp.ParserState {
case ToolCallFound:
got = append(got, toolCalls...)
add = false
// actualTokens.Reset()
case ToolCallSendTokens:
actualTokens.WriteString(s)
add = false
case ToolCallAccumulate:
add = false
case ToolCallSendPartial:
actualTokens.WriteString(" " + leftover)
add = false
}
}
// s = strings.TrimSpace(s)
if add {
actualTokens.WriteString(s)
}
}
if !tt.wantErr {
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
if !success && !tt.wantErr {
t.Errorf("expected success but got errors")
// Compare tool calls if we expect any
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
t.Errorf("tool calls mismatch (-got +want):\n%s", diff)
}
// Compare tokens if we expect any
stripped := strings.TrimSpace(actualTokens.String())
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)