From 779547fcdeeae6617b38f8aabd8669e43e59fe67 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 8 May 2025 18:48:44 -0700 Subject: [PATCH] checkpoint - cleanup still left, functionality setup --- server/routes.go | 43 +++------- server/tools.go | 199 +++++++++++++++++++++++++------------------ server/tools_test.go | 118 +++++++++++-------------- 3 files changed, 177 insertions(+), 183 deletions(-) diff --git a/server/routes.go b/server/routes.go index 3fdd24a61..47c1f4a06 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1486,10 +1486,9 @@ func (s *Server) ChatHandler(c *gin.Context) { go func() { defer close(ch) // var sb strings.Builder - var toolCallIndex int = 0 - var tp *ToolParser + var toolParser *ToolParser if len(req.Tools) > 0 { - tp = NewToolParser(m) + toolParser = NewToolParser(m) } if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ @@ -1517,37 +1516,19 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - if len(req.Tools) > 0 && tp.state != Done { - fmt.Println("checking tool calls") - /* - This should give us a few return things we shouldnt have to build things up. - 1. tool calls if any - 2. leftover tokens if any - this happens in the partial case where we have a prefix inside a string - 3. if we need to skip this loop and not send anything back - - - between these three things, we should just be switching on either the state or something to capture this - potentially consider a difference between internal and external state - */ - toolCalls, leftover, ok := tp.ParseToolCalls(r.Content) - - // todo: this should just be one check/state coming back from the parse tool calls - if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) { + if len(req.Tools) > 0 && !toolParser.Done { + toolCalls, leftover := toolParser.ParseToolCalls(r.Content) + switch toolParser.ParserState { + case ToolCallAccumulate: + // tokens are accumulated in the tool parser return - } - // todo: our second state also should just be a var - if tp.state == ContainsPartialPrefix { - fmt.Println("sending tokens", leftover) + case ToolCallSendTokens: + // tokens are sent back in the response + case ToolCallSendPartial: + // tokens not needed for parsing are sent back in the response res.Message.Content = leftover - } - // TODO: this can be done inside the parse tool calls - if ok && len(toolCalls) > 0 { + case ToolCallFound: res.Message.ToolCalls = toolCalls - for i := range toolCalls { - toolCalls[i].Function.Index = toolCallIndex - toolCallIndex++ - } - // Remove content when tool call is present res.Message.Content = "" } } diff --git a/server/tools.go b/server/tools.go index faa2cb9e2..3058456a1 100644 --- a/server/tools.go +++ b/server/tools.go @@ -10,12 +10,14 @@ import ( gotmpl "text/template" jsonv2 "github.com/go-json-experiment/json" + jsontext "github.com/go-json-experiment/json/jsontext" "github.com/ollama/ollama/api" ) type State int +// TODO: potentially coalesce states const ( SendTokens State = iota GreedyToolWithPrefix @@ -26,6 +28,30 @@ const ( Done ) +type ExternalState int + +const ( + ToolCallFound ExternalState = iota + ToolCallSendPartial + ToolCallAccumulate + ToolCallSendTokens +) + +func (s ExternalState) String() string { + switch s { + case ToolCallFound: + return "ToolCallFound" + case ToolCallSendPartial: + return "ToolCallSendPartial" + case ToolCallAccumulate: + return "ToolCallAccumulate" + case ToolCallSendTokens: + return "ToolCallSendTokens" + default: + return fmt.Sprintf("Unknown ExternalState (%d)", s) + } +} + func (s State) String() string { switch s { case SendTokens: @@ -47,13 +73,18 @@ func (s State) String() string { } } +// TODO: simplify if possible type ToolParser struct { - tmpl *gotmpl.Template - state State - sb *strings.Builder - toolPrefix string + tmpl *gotmpl.Template + state State + sb *strings.Builder + toolPrefix string + toolIndex int + ParserState ExternalState + Done bool } +// ? move to a separate file // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { @@ -73,7 +104,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { return nil, false, false } - // ! this can be either a map or an array + // this can be either a map or an array var temp any err := jsonv2.Unmarshal(b.Bytes(), &temp) if err != nil { @@ -128,6 +159,14 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { if name == "" || arguments == "" { return nil, false, false } + + // TODO: there is probably some underlying repeat work here to avoid + // This incrementally decodes the JSON string and returns the first parsedobject + dec := jsontext.NewDecoder(strings.NewReader(s)) + if got, err := dec.ReadValue(); err == nil { + s = got.String() + } + var responseObjects any err = jsonv2.Unmarshal([]byte(s), &responseObjects) if err != nil { @@ -137,7 +176,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) { return nil, true, false } else { fmt.Printf("Other error: %v\n", err) - fmt.Println("exiting", p.state) + fmt.Println("exiting from JSON parsing", p.state) return nil, false, false } } @@ -182,10 +221,14 @@ func (p *ToolParser) updateOutputState(ok bool, partial bool, tcs []api.ToolCall if p.state == GreedyToolNoPrefix { fmt.Println(" Subcase: GreedyToolNoPrefix - marking as done") p.state = Done + // p.ParserState = DoneFR + p.ParserState = ToolCallSendTokens + p.Done = true } if p.state == GreedyToolWithPrefix { fmt.Println(" Subcase: GreedyToolWithPrefix - switching to SendTokens") p.state = SendTokens + p.ParserState = ToolCallSendTokens } p.sb.Reset() case !ok && partial: @@ -197,131 +240,118 @@ func (p *ToolParser) updateOutputState(ok bool, partial bool, tcs []api.ToolCall // do not parse again in the greedy JSON case as soon as we have a tool call if p.state == GreedyToolWithPrefix { p.state = SendTokens + p.ParserState = ToolCallFound + p.state = Done + p.Done = true } else if p.state == GreedyToolNoPrefix { fmt.Println(" Subcase: Greedy modes - marking done and switching to SendTokens") p.state = Done + p.Done = true } p.sb.Reset() } + p.updateExternalState(tcs) } -func (p *ToolParser) updateInputState(s string, hasPrefix bool) (string, bool) { +func (p *ToolParser) updateExternalState(tcs []api.ToolCall) { + if (p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix || p.state == ToolSuffix) || (p.state == ForceTools && len(tcs) == 0) { + p.ParserState = ToolCallAccumulate + } else if p.state == ContainsPartialPrefix { + p.ParserState = ToolCallSendPartial + } else if len(tcs) > 0 { + p.ParserState = ToolCallFound + } else if p.state == SendTokens { + p.ParserState = ToolCallSendTokens + } +} + +// string, and if it has a prefix +func (p *ToolParser) checkPrefix(s string) (string, bool) { if p.toolPrefix == "" { return s, true } - + original := s + // s = strings.TrimSpace(s) + s, hasPrefix := strings.CutPrefix(s, p.toolPrefix) if hasPrefix { + fmt.Println("has prefix", s) p.state = ForceTools // partial tool possibly } else if strings.HasPrefix(p.toolPrefix, s) { slog.Debug("tool prefix partially", "prefix", p.toolPrefix, "content", s) // TODO: could possibly err maybe this should be greedy instead? p.state = ForceTools + // this would basically be a no op on rest of the input return "", false - } else if strings.Contains(s, p.toolPrefix) { - idx := strings.Index(s, p.toolPrefix) - if idx != -1 { - // still keeps the prefix - p.state = ContainsPartialPrefix - p.sb.Reset() - p.sb.WriteString(s[idx:]) - return s[:idx], false - } - } - // Special token end case - if strings.HasSuffix(s, p.toolPrefix[2:]) { - // can be with string or just the token - if hasPrefix { - s = strings.TrimSpace(s[:len(s)-(len(p.toolPrefix)+1)]) - } else { - p.state = ToolSuffix - p.sb.Reset() - return "", false - } - slog.Debug("setting to no tool", "content", s) - } - return s, true -} - -func (p *ToolParser) sendTokens(original string, hasPrefix bool) (string, bool) { - if p.state == SendTokens { - return "", false - } - if p.state == ContainsPartialPrefix { + // the case where "token" - send "token" back + // accounts for spaces in prefix or suffix to avoid breaking cache + } else if strings.Contains(original, p.toolPrefix) { idx := strings.Index(original, p.toolPrefix) if idx != -1 { // still keeps the prefix p.state = ContainsPartialPrefix - return original[:idx], false - } else { - fmt.Println("some weird state") - } - } else if strings.HasSuffix(original, p.toolPrefix[2:]) { - // can be with string or just the token - if hasPrefix { - original = strings.TrimSpace(original[:len(original)-(len(p.toolPrefix)+1)]) - return original, false - } else { - p.state = ToolSuffix p.sb.Reset() + // todo: see if there is a simpler way for this + idx2 := strings.Index(s, p.toolPrefix) + p.sb.WriteString(s[idx2:]) + return original[:idx], false } } - return "", true + return s, true } +// TODO: simplify the flow of this function // ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing. // Returns tool calls, whether parsing is incomplete, and any errors. -func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string, bool) { - original := s - // append input +func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string) { + fmt.Println("checking tool calls", s) + fmt.Println("external state", p.ParserState) + fmt.Println("internal state", p.state) p.sb.WriteString(s) s = p.sb.String() + s = strings.TrimSpace(s) + fmt.Println("sb", s) + p.updateExternalState(nil) if len(s) == 0 { - return nil, "", false + return nil, "" } - s, hasPrefix := strings.CutPrefix(s, p.toolPrefix) - - s, ok := p.updateInputState(s, hasPrefix) - if !ok { + s, cont := p.checkPrefix(s) + if !cont { + p.updateExternalState(nil) if p.state == ContainsPartialPrefix { - idx := strings.Index(original, p.toolPrefix) - if idx != -1 { - // still keeps the prefix - p.state = ContainsPartialPrefix - // p.sb.Reset() - // p.sb.WriteString(original[idx:]) - return nil, original[:idx], false - } else { - fmt.Println("some weird state") - } - // } - // s, ok = p.sendTokens(original, hasPrefix) - // if ok { - // return nil, s, true + return nil, s } - return nil, "", false + return nil, "" } + // stay in SendTokens unless we have a prefix if p.state == SendTokens { - return nil, "", false + fmt.Println("SendTokens - resetting buffer") + p.updateExternalState(nil) + p.sb.Reset() + return nil, "" } - var tcs []api.ToolCall - var partial bool - tcs, partial, ok = p.parseJSONToolCalls(s) + tcs, partial, ok := p.parseJSONToolCalls(s) p.updateOutputState(ok, partial, tcs) + fmt.Println("output state", p.ParserState, p.state) if !ok { - return nil, "", false + fmt.Println("returning empty tool calls") + return nil, "" } - - return tcs, "", true + for _, tc := range tcs { + tc.Function.Index = p.toolIndex + p.toolIndex++ + } + return tcs, "" } func NewToolParser(model *Model) *ToolParser { + // TODO: use new template parsing to get all tokens for the prefix templateToolPrefix, _ := ToolPrefix(model.Template.Template) templateToolPrefix = strings.TrimSpace(templateToolPrefix) tmpl, ok := ToolTemplate(model) @@ -337,9 +367,10 @@ func NewToolParser(model *Model) *ToolParser { } fmt.Println("setup state", state) return &ToolParser{ - tmpl: tmpl, - sb: &strings.Builder{}, - toolPrefix: templateToolPrefix, - state: state, + tmpl: tmpl, + sb: &strings.Builder{}, + toolPrefix: templateToolPrefix, + state: state, + ParserState: ToolCallAccumulate, } } diff --git a/server/tools_test.go b/server/tools_test.go index 29d057c75..675d1fa67 100644 --- a/server/tools_test.go +++ b/server/tools_test.go @@ -53,7 +53,6 @@ func TestParseToolCalls(t *testing.T) { output string expectedToolCall []api.ToolCall expectedTokens string - wantErr bool }{ { name: "mistral invalid json", @@ -61,7 +60,6 @@ func TestParseToolCalls(t *testing.T) { output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, expectedToolCall: []api.ToolCall{}, expectedTokens: "", - wantErr: true, }, { name: "mistral multiple tool calls - no prefix", @@ -69,7 +67,6 @@ func TestParseToolCalls(t *testing.T) { output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "", - wantErr: false, }, { name: "mistral tool calls with text in between - no prefix", @@ -78,7 +75,6 @@ func TestParseToolCalls(t *testing.T) { model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - wantErr: false, }, { name: "mistral valid json - with prefix", @@ -86,7 +82,6 @@ func TestParseToolCalls(t *testing.T) { output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "", - wantErr: false, }, { // In this case we'd be ignoring the text in between and just returning the tool calls @@ -96,7 +91,6 @@ func TestParseToolCalls(t *testing.T) { model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedToolCall: []api.ToolCall{t1, t2, t1, t2}, expectedTokens: "", - wantErr: false, }, { name: "mistral incomplete json", @@ -104,7 +98,6 @@ func TestParseToolCalls(t *testing.T) { output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, expectedToolCall: []api.ToolCall{}, expectedTokens: "", - wantErr: true, }, { name: "mistral without tool token", @@ -114,7 +107,6 @@ func TestParseToolCalls(t *testing.T) { [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedToolCall: []api.ToolCall{}, expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - wantErr: true, }, { name: "mistral without tool token - tool first", @@ -122,7 +114,6 @@ func TestParseToolCalls(t *testing.T) { output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "", - wantErr: false, }, { name: "command-r-plus with json block", @@ -147,7 +138,6 @@ func TestParseToolCalls(t *testing.T) { ` + "```", expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "", - wantErr: false, }, { name: "firefunction with functools", @@ -155,7 +145,6 @@ func TestParseToolCalls(t *testing.T) { output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "", - wantErr: false, }, { name: "llama3 with tool call tags", @@ -165,7 +154,6 @@ func TestParseToolCalls(t *testing.T) { `, expectedToolCall: []api.ToolCall{t1}, expectedTokens: "", - wantErr: false, }, { name: "xlam with tool_calls wrapper", @@ -173,7 +161,6 @@ func TestParseToolCalls(t *testing.T) { output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "", - wantErr: false, }, { name: "qwen2.5 with single tool call", @@ -181,15 +168,34 @@ func TestParseToolCalls(t *testing.T) { output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, expectedToolCall: []api.ToolCall{t1}, expectedTokens: "", - wantErr: false, }, { - name: "qwen with invalid tool token", + name: "qwen with no tool prefix", model: "qwen2.5-coder", output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "", - wantErr: false, + }, + { + name: "qwen with no tool calls", + model: "qwen2.5-coder", + output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + expectedToolCall: []api.ToolCall{}, + expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + }, + { + name: "qwen with no tool prefix", + model: "qwen2.5-coder", + output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens after call", + }, + { + name: "qwen with prefix", + model: "qwen2.5-coder", + output: ` [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "", }, { // tests the leftover logic as well @@ -198,7 +204,6 @@ func TestParseToolCalls(t *testing.T) { output: `Okay, let me think what tool we should use...{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, expectedToolCall: []api.ToolCall{t1}, expectedTokens: "Okay, let me think what tool we should use...", - wantErr: false, }, { name: "qwen3 with single tool call and thinking spaces", @@ -206,31 +211,20 @@ func TestParseToolCalls(t *testing.T) { output: `Okay, let me think what tool we should use... {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, expectedToolCall: []api.ToolCall{t1}, expectedTokens: "Okay, let me think what tool we should use...", - wantErr: false, }, - // { - // name: "qwen3 testing", - // model: "qwen3", - // output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - // expectedToolCall: []api.ToolCall{}, - // expectedTokens: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - // wantErr: true, - // }, - // { - // name: "qwen3 testing 2", - // model: "qwen3", - // output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - // expectedToolCall: []api.ToolCall{t1}, - // expectedTokens: ``, - // wantErr: true, - // }, { - name: "qwen with no tool calls", - model: "qwen2.5-coder", - output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", + name: "qwen3 testing", + model: "qwen3", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, expectedToolCall: []api.ToolCall{}, - expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", - wantErr: true, + expectedTokens: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + }, + { + name: "qwen3 testing 2", + model: "qwen3", + output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: ``, }, { name: "llama3.2 with tool call - no prefix", @@ -238,7 +232,6 @@ func TestParseToolCalls(t *testing.T) { output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, expectedToolCall: []api.ToolCall{t1}, expectedTokens: "", - wantErr: false, }, { name: "llama3.2 with incomplete tool call - no prefix", @@ -246,7 +239,6 @@ func TestParseToolCalls(t *testing.T) { output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `, expectedToolCall: []api.ToolCall{}, expectedTokens: "", - wantErr: true, }, { name: "llama3.2 with tool call - in middle", @@ -254,7 +246,6 @@ func TestParseToolCalls(t *testing.T) { output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, expectedToolCall: []api.ToolCall{}, expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - wantErr: true, }, { name: "llama3.2 - fake tool prefix", @@ -262,7 +253,6 @@ func TestParseToolCalls(t *testing.T) { output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, expectedToolCall: []api.ToolCall{}, expectedTokens: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - wantErr: true, }, } @@ -298,7 +288,6 @@ func TestParseToolCalls(t *testing.T) { m := &Model{Template: tmpl} tp := NewToolParser(m) got := []api.ToolCall{} - success := false var actualTokens strings.Builder tokens := strings.Fields(tt.output) @@ -306,40 +295,33 @@ func TestParseToolCalls(t *testing.T) { add := true s := " " + tok - // TODO(parthsareen): This logic is brittle as it mocks the logic in route, however can - if tp.state != Done { - toolCalls, leftover, ok := tp.ParseToolCalls(s) - if (tp.state == GreedyToolWithPrefix || tp.state == GreedyToolNoPrefix || tp.state == ToolSuffix) || (tp.state == ForceTools && len(toolCalls) == 0) { - continue - } - if tp.state == ContainsPartialPrefix { - // actualTokens.Reset() - actualTokens.WriteString(leftover) - t.Log("leftover", leftover) - add = false - // continue - } - if ok && len(toolCalls) > 0 { - success = true + if !tp.Done { + toolCalls, leftover := tp.ParseToolCalls(s) + switch tp.ParserState { + case ToolCallFound: got = append(got, toolCalls...) add = false - // actualTokens.Reset() + case ToolCallSendTokens: + actualTokens.WriteString(s) + add = false + case ToolCallAccumulate: + add = false + case ToolCallSendPartial: + actualTokens.WriteString(" " + leftover) + add = false } } - // s = strings.TrimSpace(s) if add { actualTokens.WriteString(s) } } - if !tt.wantErr { - if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - } - if !success && !tt.wantErr { - t.Errorf("expected success but got errors") + // Compare tool calls if we expect any + if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { + t.Errorf("tool calls mismatch (-got +want):\n%s", diff) } + + // Compare tokens if we expect any stripped := strings.TrimSpace(actualTokens.String()) if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" { t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)