add tests, organize, comments
This commit is contained in:
parent
bc83789be9
commit
8ed95a4e96
@ -1529,9 +1529,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
if len(content) > 0 {
|
if len(content) > 0 {
|
||||||
res.Message.Content = content
|
res.Message.Content = content
|
||||||
fmt.Println("sending content in response", content)
|
slog.Debug("tools: setting content to", "content", content)
|
||||||
} else if len(toolCalls) > 0 {
|
} else if len(toolCalls) > 0 {
|
||||||
fmt.Println("sending tool calls in response", toolCalls)
|
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
res.Message.Content = ""
|
res.Message.Content = ""
|
||||||
} else {
|
} else {
|
||||||
|
280
tools/tools.go
280
tools/tools.go
@ -1,9 +1,7 @@
|
|||||||
package tools
|
package tools
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
@ -16,136 +14,56 @@ import (
|
|||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: simplify if possible
|
|
||||||
type Parser struct {
|
type Parser struct {
|
||||||
greedy bool
|
greedyParse bool
|
||||||
prefixFound bool
|
prefixFound bool
|
||||||
partialPrefix bool
|
prefixPartial bool
|
||||||
tmpl *gotmpl.Template
|
tmpl *gotmpl.Template
|
||||||
sb *strings.Builder
|
sb *strings.Builder
|
||||||
prefix string
|
prefix string
|
||||||
index int
|
index int
|
||||||
|
name string
|
||||||
|
arguments string
|
||||||
Done bool
|
Done bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
|
// It first tries to incrementally decode the JSON to handle partial inputs.
|
||||||
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
// Returns:
|
||||||
fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s)
|
// - []api.ToolCall: The parsed tool calls if successful
|
||||||
|
// - bool: True if JSON is incomplete and needs more input
|
||||||
var b bytes.Buffer
|
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
|
// First try incremental decoding to handle partial JSON
|
||||||
"ToolCalls": {
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "@@name@@",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
|
||||||
"@@argument@@": 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
fmt.Printf("failed to execute template: error=%v\n", err)
|
|
||||||
return nil, false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// this can be either a map or an array
|
|
||||||
var temp any
|
|
||||||
err := jsonv2.Unmarshal(b.Bytes(), &temp)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("failed to unmarshal template: error=%v\n", err)
|
|
||||||
return nil, false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
var collect func(any) []map[string]any
|
|
||||||
collect = func(obj any) (all []map[string]any) {
|
|
||||||
switch o := obj.(type) {
|
|
||||||
case map[string]any:
|
|
||||||
all = append(all, o)
|
|
||||||
for _, v := range o {
|
|
||||||
all = append(all, collect(v)...)
|
|
||||||
}
|
|
||||||
case []any:
|
|
||||||
for _, v := range o {
|
|
||||||
all = append(all, collect(v)...)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// TODO: err or fallback
|
|
||||||
fmt.Printf("collect encountered unknown type: type=%T\n", obj)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return all
|
|
||||||
}
|
|
||||||
|
|
||||||
var templateObjects []map[string]any
|
|
||||||
switch t := temp.(type) {
|
|
||||||
case map[string]any:
|
|
||||||
templateObjects = []map[string]any{t}
|
|
||||||
case []map[string]any:
|
|
||||||
templateObjects = t
|
|
||||||
// ! fallback?
|
|
||||||
case []any:
|
|
||||||
templateObjects = collect(t)
|
|
||||||
}
|
|
||||||
if len(templateObjects) == 0 {
|
|
||||||
fmt.Println("no template objects found")
|
|
||||||
return nil, false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// find the keys that correspond to the name and arguments fields
|
|
||||||
var name, arguments string
|
|
||||||
for k, v := range templateObjects[0] {
|
|
||||||
switch v.(type) {
|
|
||||||
case string:
|
|
||||||
name = k
|
|
||||||
fmt.Printf("found name field: key=%s\n", k)
|
|
||||||
case map[string]any:
|
|
||||||
arguments = k
|
|
||||||
fmt.Printf("found arguments field: key=%s\n", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if name == "" || arguments == "" {
|
|
||||||
fmt.Printf("missing required fields: name_found=%v arguments_found=%v\n", name != "", arguments != "")
|
|
||||||
return nil, false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: there is probably some underlying repeat work here to avoid
|
|
||||||
// This incrementally decodes the JSON string and returns the first parsedobject
|
|
||||||
dec := jsontext.NewDecoder(strings.NewReader(s))
|
dec := jsontext.NewDecoder(strings.NewReader(s))
|
||||||
if got, err := dec.ReadValue(); err == nil {
|
if got, err := dec.ReadValue(); err == nil {
|
||||||
s = got.String()
|
s = got.String()
|
||||||
fmt.Printf("decoded JSON value: value=%s\n", s)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var responseObjects any
|
// Attempt full unmarshal of the JSON
|
||||||
err = jsonv2.Unmarshal([]byte(s), &responseObjects)
|
var resp any
|
||||||
|
err := jsonv2.Unmarshal([]byte(s), &resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Handle incomplete JSON cases
|
||||||
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
|
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
|
||||||
fmt.Println("incomplete JSON detected")
|
slog.Debug("incomplete JSON detected", "input", s)
|
||||||
return nil, true, false
|
return nil, true
|
||||||
} else {
|
|
||||||
fmt.Printf("failed to unmarshal response: error=%v\n", err)
|
|
||||||
return nil, false, false
|
|
||||||
}
|
}
|
||||||
|
slog.Debug("failed to unmarshal response", "error", err)
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Collect all nested objects that could contain tool calls
|
||||||
var objs []map[string]any
|
var objs []map[string]any
|
||||||
objs = append(objs, collect(responseObjects)...)
|
objs = append(objs, collect(resp)...)
|
||||||
if len(objs) == 0 {
|
if len(objs) == 0 {
|
||||||
return nil, false, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("collected objects: count=%d\n", len(objs))
|
|
||||||
|
|
||||||
var toolCalls []api.ToolCall
|
var toolCalls []api.ToolCall
|
||||||
for _, kv := range objs {
|
for _, kv := range objs {
|
||||||
n, nok := kv[name].(string)
|
n, nok := kv[p.name].(string)
|
||||||
a, aok := kv[arguments].(map[string]any)
|
a, aok := kv[p.arguments].(map[string]any)
|
||||||
if nok && aok {
|
if nok && aok {
|
||||||
fmt.Printf("found valid tool call: name=%s\n", n)
|
|
||||||
toolCalls = append(toolCalls, api.ToolCall{
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: n,
|
Name: n,
|
||||||
@ -155,130 +73,170 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("parsed tool calls: count=%d\n", len(toolCalls))
|
// Valid JSON, no tool calls found
|
||||||
return toolCalls, false, true
|
if len(toolCalls) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return toolCalls, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// prefix stripped string if any, prefix found, and if we should accumulate
|
// checkPrefix processes a string to find and handle a prefix pattern.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - The processed string with prefix removed if found
|
||||||
|
// - Whether the prefix was found at the start of the string
|
||||||
|
// - Whether to continue parsing
|
||||||
func (p *Parser) checkPrefix(s string) (string, bool, bool) {
|
func (p *Parser) checkPrefix(s string) (string, bool, bool) {
|
||||||
|
// Keep original for overlap checks
|
||||||
|
original := s
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return "", false, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no prefix defined, just return trimmed string
|
||||||
if p.prefix == "" {
|
if p.prefix == "" {
|
||||||
return s, false, true
|
return s, false, true
|
||||||
}
|
}
|
||||||
original := s
|
|
||||||
s = strings.TrimSpace(s)
|
|
||||||
s, hasPrefix := strings.CutPrefix(s, p.prefix)
|
|
||||||
if hasPrefix {
|
|
||||||
// partial tool possibly - accumulate
|
|
||||||
return s, true, true
|
|
||||||
} else if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
|
|
||||||
// p.state = PartialPrefix
|
|
||||||
p.partialPrefix = true
|
|
||||||
return original[0 : len(original)-overlap], false, false
|
|
||||||
} else if idx := strings.Index(original, p.prefix); idx != -1 {
|
|
||||||
// Found prefix in middle of string, keep only content before prefix
|
|
||||||
// accounts for spaces in prefix or suffix to avoid breaking cache
|
|
||||||
p.partialPrefix = true
|
|
||||||
p.sb.Reset()
|
|
||||||
|
|
||||||
|
// Check for prefix at start of string
|
||||||
|
if processedStr, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
|
||||||
|
// Found prefix at start - accumulate for potential tool
|
||||||
|
return processedStr, true, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if prefix overlaps end of string
|
||||||
|
if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
|
||||||
|
p.prefixPartial = true
|
||||||
|
// Return everything except overlapping portion
|
||||||
|
p.sb.Reset()
|
||||||
|
p.sb.WriteString(original[len(original)-overlap:])
|
||||||
|
return original[0 : len(original)-overlap], false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if prefix appears in middle of string
|
||||||
|
if idx := strings.Index(original, p.prefix); idx != -1 {
|
||||||
|
p.prefixPartial = true
|
||||||
|
// Save remainder starting at prefix for next pass
|
||||||
|
p.sb.Reset()
|
||||||
p.sb.WriteString(strings.TrimSpace(original[idx:]))
|
p.sb.WriteString(strings.TrimSpace(original[idx:]))
|
||||||
|
// Return everything before prefix
|
||||||
return original[:idx], false, false
|
return original[:idx], false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
p.partialPrefix = false
|
// No prefix found
|
||||||
|
p.prefixPartial = false
|
||||||
return s, false, true
|
return s, false, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add processes a string input to parse tool calls and content.
|
||||||
|
// It handles prefix detection and JSON parsing to extract tool calls.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - tools: Any parsed tool calls
|
||||||
|
// - content: Non-tool call content
|
||||||
|
// - err: Error if parsing failed
|
||||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) {
|
func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) {
|
||||||
slog.Debug("adding tool calls", "input", s)
|
|
||||||
|
|
||||||
p.sb.WriteString(s)
|
p.sb.WriteString(s)
|
||||||
s = p.sb.String()
|
s = p.sb.String()
|
||||||
|
|
||||||
if len(s) == 0 {
|
if len(s) == 0 {
|
||||||
return nil, "", nil
|
return nil, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
s, prefixFound, cont := p.checkPrefix(s)
|
// Check for prefix pattern in input
|
||||||
|
s, prefixFound, shouldContinue := p.checkPrefix(s)
|
||||||
if !cont {
|
if !shouldContinue {
|
||||||
if s != "" {
|
if s != "" {
|
||||||
// send only the content back, prefix exists
|
// Return content before prefix
|
||||||
return nil, s, nil
|
return nil, s, nil
|
||||||
}
|
}
|
||||||
// accumulate case
|
// Need more input to complete prefix
|
||||||
return nil, "", nil
|
return nil, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// circuit breaker
|
// Update prefix found state
|
||||||
if prefixFound {
|
if prefixFound {
|
||||||
p.prefixFound = true
|
p.prefixFound = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// for cases with a prefix in template
|
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||||
if p.prefix != "" && !p.greedy && !p.prefixFound {
|
if !p.greedyParse && !p.prefixFound {
|
||||||
// send tokens down
|
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
return nil, "", errors.New("prefix not found")
|
return nil, "", errors.New("prefix not found")
|
||||||
}
|
}
|
||||||
// we have a prefix or are in json mode
|
|
||||||
tcs, partial, ok := p.parseJSONToolCalls(s)
|
toolCalls, isPartial := p.parseJSONToolCalls(s)
|
||||||
if partial {
|
if isPartial {
|
||||||
// accumulate case
|
// Need more input to complete JSON
|
||||||
return nil, "", nil
|
return nil, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.greedy = false
|
// Do not try greedy parsing if partial JSON not found
|
||||||
if !ok {
|
p.greedyParse = false
|
||||||
// will not be a partial at this point
|
|
||||||
|
// Handle invalid tool call format
|
||||||
|
if len(toolCalls) == 0 {
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
// send tokens
|
|
||||||
if p.prefix == "" {
|
if p.prefix == "" {
|
||||||
p.Done = true
|
p.Done = true
|
||||||
}
|
}
|
||||||
if p.prefixFound {
|
if p.prefixFound {
|
||||||
// drop tokens instead - sb is reset, no tokens sent to user
|
// Drop tokens since prefix was found
|
||||||
return nil, "", nil
|
return nil, "", nil
|
||||||
}
|
}
|
||||||
return nil, "", errors.New("failed to parse tool calls")
|
return nil, s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tcs {
|
for _, tc := range toolCalls {
|
||||||
tc.Function.Index = p.index
|
tc.Function.Index = p.index
|
||||||
p.index++
|
p.index++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mark as done if no prefix needed
|
||||||
if p.prefix == "" {
|
if p.prefix == "" {
|
||||||
p.Done = true
|
p.Done = true
|
||||||
}
|
}
|
||||||
|
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
return tcs, "", nil
|
return toolCalls, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewParser creates a new tool call parser from a template. It extracts the tool call format,
|
||||||
|
// prefix, and field names from the template to use for parsing tool calls from model output.
|
||||||
|
//
|
||||||
|
// Returns an error if the template does not contain valid tool call formatting.
|
||||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||||
parsedTemplate, err := template.Parse(templateToProcess.Root.String())
|
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if parsedTemplate == nil {
|
if parsed == nil {
|
||||||
return nil, errors.New("failed to parse template")
|
return nil, errors.New("failed to parse template")
|
||||||
}
|
}
|
||||||
|
|
||||||
toolCallTemplate, hasToolCalls := toolTemplate(parsedTemplate)
|
tt, tc := toolTemplate(parsed)
|
||||||
if !hasToolCalls {
|
if !tc {
|
||||||
return nil, errors.New("failed to find tool template")
|
return nil, errors.New("failed to find tool calls in template")
|
||||||
}
|
}
|
||||||
if toolCallTemplate == nil {
|
if tt == nil {
|
||||||
return nil, errors.New("failed to find tool template")
|
return nil, errors.New("failed to find tool template")
|
||||||
}
|
}
|
||||||
|
|
||||||
toolPrefix, _ := ToolPrefix(templateToProcess)
|
tp := toolPrefix(templateToProcess)
|
||||||
toolPrefix = strings.TrimSpace(toolPrefix)
|
tp = strings.TrimSpace(tp)
|
||||||
|
|
||||||
|
name, arguments, err := extractToolArgs(tt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Printf("creating new tool parser: prefix=%s\n", toolPrefix)
|
|
||||||
return &Parser{
|
return &Parser{
|
||||||
tmpl: toolCallTemplate,
|
tmpl: tt,
|
||||||
sb: &strings.Builder{},
|
sb: &strings.Builder{},
|
||||||
prefix: toolPrefix,
|
prefix: tp,
|
||||||
greedy: true,
|
greedyParse: true,
|
||||||
|
name: name,
|
||||||
|
arguments: arguments,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -6,11 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
gotmpl "text/template"
|
|
||||||
"text/template/parse"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
@ -206,6 +203,27 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
expectedToolCall: []api.ToolCall{t1, t2},
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
expectedTokens: "",
|
expectedTokens: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls without prefix and valid tool call",
|
||||||
|
model: "qwen2.5-coder",
|
||||||
|
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{t1, t2},
|
||||||
|
expectedTokens: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls without prefix and invalid tool call",
|
||||||
|
model: "qwen2.5-coder",
|
||||||
|
output: `[{"options": "foo"}]`,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: `[{"options": "foo"}]`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen2.5 tool calls with prefix and invalid tool call",
|
||||||
|
model: "qwen2.5-coder",
|
||||||
|
output: `<tool_call> [{"options": "foo"}] </tool_call> `,
|
||||||
|
expectedToolCall: []api.ToolCall{},
|
||||||
|
expectedTokens: ``,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
|
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
|
||||||
model: "qwen3",
|
model: "qwen3",
|
||||||
@ -239,14 +257,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
model: "qwen3",
|
model: "qwen3",
|
||||||
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: `<think></think> fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
expectedTokens: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
||||||
model: "qwen3",
|
model: "qwen3",
|
||||||
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: `<think></think> fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
expectedTokens: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen3 invalid tool call with malformed tool prefix",
|
name: "qwen3 invalid tool call with malformed tool prefix",
|
||||||
@ -332,6 +350,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
toolCalls, content, err := tp.Add(s)
|
toolCalls, content, err := tp.Add(s)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if content != "" {
|
if content != "" {
|
||||||
|
fmt.Printf("content: %q\n", content)
|
||||||
gotTokens.WriteString(content)
|
gotTokens.WriteString(content)
|
||||||
add = false
|
add = false
|
||||||
} else if len(toolCalls) > 0 {
|
} else if len(toolCalls) > 0 {
|
||||||
@ -363,24 +382,101 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toolTemplateHelper(t *testing.T, tmpl *template.Template) (*gotmpl.Template, bool) {
|
func TestParseJSONToolCalls(t *testing.T) {
|
||||||
// create a subtree from the node that ranges over .ToolCalls
|
tests := []struct {
|
||||||
|
name string
|
||||||
tmpl2 := tmpl.Subtree(func(n parse.Node) bool {
|
input string
|
||||||
if t, ok := n.(*parse.RangeNode); ok {
|
parser *Parser
|
||||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
wantToolCalls []api.ToolCall
|
||||||
}
|
wantPartial bool
|
||||||
|
wantValid bool
|
||||||
return false
|
}{
|
||||||
})
|
{
|
||||||
|
name: "valid single tool call",
|
||||||
if tmpl2.Root != nil {
|
input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`,
|
||||||
t.Log("tmpl2", tmpl2.Root.String())
|
parser: &Parser{name: "name", arguments: "arguments"},
|
||||||
|
wantToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "test_tool",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg1": "value1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantPartial: false,
|
||||||
|
wantValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete JSON",
|
||||||
|
input: `{"name": "test_tool", "arguments": {"arg1": `,
|
||||||
|
parser: &Parser{name: "name", arguments: "arguments"},
|
||||||
|
wantToolCalls: nil,
|
||||||
|
wantPartial: true,
|
||||||
|
wantValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON",
|
||||||
|
input: `not json at all`,
|
||||||
|
parser: &Parser{name: "name", arguments: "arguments"},
|
||||||
|
wantToolCalls: nil,
|
||||||
|
wantPartial: false,
|
||||||
|
wantValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing required fields",
|
||||||
|
input: `{"other": "field"}`,
|
||||||
|
parser: &Parser{name: "name", arguments: "arguments"},
|
||||||
|
wantToolCalls: nil,
|
||||||
|
wantPartial: false,
|
||||||
|
wantValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool calls in array",
|
||||||
|
input: `[
|
||||||
|
{"name": "tool1", "arguments": {"arg1": 1}},
|
||||||
|
{"name": "tool2", "arguments": {"arg2": "value"}}
|
||||||
|
]`,
|
||||||
|
parser: &Parser{name: "name", arguments: "arguments"},
|
||||||
|
wantToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool1",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg1": float64(1),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "tool2",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"arg2": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantPartial: false,
|
||||||
|
wantValid: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if tmpl2 == nil {
|
for _, tt := range tests {
|
||||||
return nil, false
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
}
|
gotCalls, gotPartial := tt.parser.parseJSONToolCalls(tt.input)
|
||||||
|
|
||||||
return tmpl2, true
|
if gotPartial != tt.wantPartial {
|
||||||
|
t.Errorf("parseJSONToolCalls() partial = %v, want %v", gotPartial, tt.wantPartial)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(gotCalls) != 0 != tt.wantValid {
|
||||||
|
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantValid)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" {
|
||||||
|
t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
132
tools/utils.go
132
tools/utils.go
@ -1,17 +1,27 @@
|
|||||||
package tools
|
package tools
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
gotmpl "text/template"
|
gotmpl "text/template"
|
||||||
"text/template/parse"
|
"text/template/parse"
|
||||||
|
|
||||||
|
jsonv2 "github.com/go-json-experiment/json"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls"
|
// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
|
||||||
func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
|
||||||
|
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: The extracted text following the first ".ToolCalls" condition found
|
||||||
|
// - bool: Whether a ".ToolCalls" condition was found in the template
|
||||||
|
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
|
||||||
if tmpl == nil || tmpl.Tree == nil {
|
if tmpl == nil || tmpl.Tree == nil {
|
||||||
slog.Debug("TextAfterToolCalls: template or tree is nil")
|
slog.Debug("TextAfterToolCalls: template or tree is nil")
|
||||||
return "", false
|
return "", false
|
||||||
@ -29,7 +39,7 @@ func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
|||||||
|
|
||||||
switch n := node.(type) {
|
switch n := node.(type) {
|
||||||
case *parse.IfNode:
|
case *parse.IfNode:
|
||||||
if nodeContainsToolCalls(n) {
|
if isToolCallsNode(n) {
|
||||||
// Collect immediate TextNode(s) at start of IfNode's list
|
// Collect immediate TextNode(s) at start of IfNode's list
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, innerNode := range n.List.Nodes {
|
for _, innerNode := range n.List.Nodes {
|
||||||
@ -76,8 +86,8 @@ func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
|||||||
return result, found
|
return result, found
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to detect if a node's condition includes ".ToolCalls"
|
// isToolCallsNode detects if a node's condition includes ".ToolCalls"
|
||||||
func nodeContainsToolCalls(n *parse.IfNode) bool {
|
func isToolCallsNode(n *parse.IfNode) bool {
|
||||||
for _, cmd := range n.Pipe.Cmds {
|
for _, cmd := range n.Pipe.Cmds {
|
||||||
for _, arg := range cmd.Args {
|
for _, arg := range cmd.Args {
|
||||||
if field, ok := arg.(*parse.FieldNode); ok {
|
if field, ok := arg.(*parse.FieldNode); ok {
|
||||||
@ -90,16 +100,17 @@ func nodeContainsToolCalls(n *parse.IfNode) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolPrefix returns the prefix for the tool call if it exists
|
|
||||||
// TODO(parthsareen): get full prefix from the template instead of just the first token
|
// TODO(parthsareen): get full prefix from the template instead of just the first token
|
||||||
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
|
|
||||||
tokenText, ok := extractToolCallsTemplate(tmpl)
|
// toolPrefix returns the prefix for the tool call if it exists from a template
|
||||||
|
func toolPrefix(tmpl *gotmpl.Template) string {
|
||||||
|
tokenText, ok := extractToolCallsFormat(tmpl)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", false
|
return ""
|
||||||
}
|
}
|
||||||
tokenText = strings.TrimSpace(tokenText)
|
tokenText = strings.TrimSpace(tokenText)
|
||||||
if tokenText == "" {
|
if tokenText == "" {
|
||||||
return "", false
|
return ""
|
||||||
}
|
}
|
||||||
first := strings.Fields(tokenText)[0]
|
first := strings.Fields(tokenText)[0]
|
||||||
|
|
||||||
@ -116,19 +127,23 @@ func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
|
|||||||
}
|
}
|
||||||
if start != -1 && end != -1 {
|
if start != -1 && end != -1 {
|
||||||
// return the token including the [ or < and the ] or >
|
// return the token including the [ or < and the ] or >
|
||||||
return tokenText[start : end+1], true
|
return tokenText[start : end+1]
|
||||||
} else if start != -1 {
|
} else if start != -1 {
|
||||||
// get until the [ or < - in the case tag was not closed
|
// get until the [ or < - in the case tag was not closed
|
||||||
return tokenText[:start], true
|
return tokenText[:start]
|
||||||
} else if end != -1 {
|
} else if end != -1 {
|
||||||
// get after the ] or > - in the case tag was not opened
|
// get after the ] or > - in the case tag was not opened
|
||||||
return tokenText[end+1:], true
|
return tokenText[end+1:]
|
||||||
}
|
}
|
||||||
return first, true
|
return first
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toolTemplate creates a subtree from the node that ranges over .ToolCalls
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *gotmpl.Template: The subtree containing the .ToolCalls range
|
||||||
|
// - bool: Whether a .ToolCalls range was found in the template
|
||||||
func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
|
func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
|
||||||
// create a subtree from the node that ranges over .ToolCalls
|
|
||||||
tmpl := t.Subtree(func(n parse.Node) bool {
|
tmpl := t.Subtree(func(n parse.Node) bool {
|
||||||
if t, ok := n.(*parse.RangeNode); ok {
|
if t, ok := n.(*parse.RangeNode); ok {
|
||||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||||
@ -144,6 +159,10 @@ func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
|
|||||||
return tmpl, true
|
return tmpl, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// suffixOverlap returns the length of the longest suffix overlap between two strings
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - int: The length of the longest suffix overlap
|
||||||
func suffixOverlap(s, delim string) int {
|
func suffixOverlap(s, delim string) int {
|
||||||
max := min(len(delim), len(s))
|
max := min(len(delim), len(s))
|
||||||
for i := max; i > 0; i-- {
|
for i := max; i > 0; i-- {
|
||||||
@ -153,3 +172,86 @@ func suffixOverlap(s, delim string) int {
|
|||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractToolArgs executes a template with a known tool call format to extract the name and arguments
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: The name of the tool call
|
||||||
|
// - string: The arguments of the tool call
|
||||||
|
// - error: Error if parsing failed
|
||||||
|
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||||
|
"ToolCalls": {
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "@@name@@",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"@@argument@@": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var obj any
|
||||||
|
err = jsonv2.Unmarshal(b.Bytes(), &obj)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var objs []map[string]any
|
||||||
|
switch v := obj.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
objs = []map[string]any{v}
|
||||||
|
case []map[string]any:
|
||||||
|
objs = v
|
||||||
|
case []any:
|
||||||
|
objs = collect(v)
|
||||||
|
}
|
||||||
|
if len(objs) == 0 {
|
||||||
|
return "", "", errors.New("no template objects found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the keys that correspond to the name and arguments fields
|
||||||
|
for k, v := range objs[0] {
|
||||||
|
switch v.(type) {
|
||||||
|
case string:
|
||||||
|
name = k
|
||||||
|
case map[string]any:
|
||||||
|
arguments = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" || arguments == "" {
|
||||||
|
slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
|
||||||
|
return "", "", errors.New("missing required fields in tool call template")
|
||||||
|
}
|
||||||
|
|
||||||
|
return name, arguments, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// collect recursively traverses an object to collect all nested maps
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []map[string]any: A slice of all nested maps found in the object
|
||||||
|
func collect(obj any) []map[string]any {
|
||||||
|
var all []map[string]any
|
||||||
|
switch o := obj.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
all = append(all, o)
|
||||||
|
for _, v := range o {
|
||||||
|
all = append(all, collect(v)...)
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, v := range o {
|
||||||
|
all = append(all, collect(v)...)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
@ -3,74 +3,133 @@ package tools
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
gotmpl "text/template"
|
gotmpl "text/template"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestExtractToolCallsFormat(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
want string
|
||||||
|
found bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil template",
|
||||||
|
template: "",
|
||||||
|
want: "",
|
||||||
|
found: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "basic tool call with text",
|
||||||
|
template: "{{if .ToolCalls}}Hello world{{end}}",
|
||||||
|
want: "Hello world",
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with json format",
|
||||||
|
template: "{{if .ToolCalls}}```json\n{{end}}",
|
||||||
|
want: "```json\n",
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call in range",
|
||||||
|
template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
|
||||||
|
want: "",
|
||||||
|
found: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with multiple text nodes",
|
||||||
|
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
|
||||||
|
want: "First text",
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested if without tool calls",
|
||||||
|
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
|
||||||
|
want: "",
|
||||||
|
found: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
tmpl, err := gotmpl.New("test").Parse(tc.template)
|
||||||
|
if err != nil && tc.template != "" {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, found := extractToolCallsFormat(tmpl)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("got text %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
if found != tc.found {
|
||||||
|
t.Errorf("got found %v, want %v", found, tc.found)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestToolPrefix(t *testing.T) {
|
func TestToolPrefix(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
template string
|
template string
|
||||||
want string
|
want string
|
||||||
ok bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "basic tool call with action prefix",
|
name: "basic tool call with action prefix",
|
||||||
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||||
want: "Action:",
|
want: "Action:",
|
||||||
ok: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "incomplete functools bracket",
|
name: "incomplete functools bracket",
|
||||||
template: "{{if .ToolCalls}}functools[{{end}}",
|
template: "{{if .ToolCalls}}functools[{{end}}",
|
||||||
want: "functools",
|
want: "functools",
|
||||||
ok: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "tool call with angle brackets",
|
name: "tool call with angle brackets",
|
||||||
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
||||||
want: "<tool_call>",
|
want: "<tool_call>",
|
||||||
ok: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple tool call formats",
|
name: "multiple tool call formats",
|
||||||
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
||||||
want: "[tool_call]",
|
want: "[tool_call]",
|
||||||
ok: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "single angle bracket tool call",
|
name: "single angle bracket tool call",
|
||||||
template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
||||||
want: "<tool_call>",
|
want: "<tool_call>",
|
||||||
ok: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "incomplete angle bracket after tool call",
|
name: "incomplete angle bracket after tool call",
|
||||||
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
||||||
want: "[tool_call]",
|
want: "[tool_call]",
|
||||||
ok: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "angle bracket prefix with tool call",
|
name: "angle bracket prefix with tool call",
|
||||||
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
||||||
want: "<tool_call>",
|
want: "<tool_call>",
|
||||||
ok: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "uppercase tool call with incomplete bracket",
|
name: "uppercase tool call with incomplete bracket",
|
||||||
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||||
want: "[TOOL_CALL]",
|
want: "[TOOL_CALL]",
|
||||||
ok: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "uppercase tool call with adjacent bracket",
|
name: "uppercase tool call with adjacent bracket",
|
||||||
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||||
want: "[TOOL_CALL]",
|
want: "[TOOL_CALL]",
|
||||||
ok: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "tool call with pipe delimiters",
|
name: "tool call with pipe delimiters",
|
||||||
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
||||||
want: "<|tool_call|>",
|
want: "<|tool_call|>",
|
||||||
ok: true,
|
},
|
||||||
|
{
|
||||||
|
name: "tool with no prefix",
|
||||||
|
template: "{{if .ToolCalls}}{{end}}",
|
||||||
|
want: "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,18 +139,135 @@ func TestToolPrefix(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse template: %v", err)
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
}
|
}
|
||||||
got, ok := ToolPrefix(tmpl)
|
got := toolPrefix(tmpl)
|
||||||
if got != tt.want {
|
if got != tt.want {
|
||||||
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
||||||
}
|
}
|
||||||
if ok != tt.ok {
|
|
||||||
t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTextAfterToolCalls(t *testing.T) {
|
func TestToolTemplate(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic tool call range",
|
||||||
|
template: "{{range .ToolCalls}}test{{end}}",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no tool calls",
|
||||||
|
template: "{{range .Other}}test{{end}}",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested tool calls",
|
||||||
|
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty template",
|
||||||
|
template: "",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool calls in if statement",
|
||||||
|
template: "{{if .ToolCalls}}test{{end}}",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := template.Parse(tmpl.Root.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, got := toolTemplate(parsed)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("toolTemplate() = %v; want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSuffixOverlap(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
s string
|
||||||
|
d string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no overlap",
|
||||||
|
s: "hello world",
|
||||||
|
d: "",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full overlap",
|
||||||
|
s: "<tool_call>",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 11,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial overlap",
|
||||||
|
s: "text <tool_call>",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 11,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delimiter longer than string",
|
||||||
|
s: "<tool>",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
s: "",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty delimiter",
|
||||||
|
s: "<tool_call>",
|
||||||
|
d: "",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single char overlap",
|
||||||
|
s: "test<",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial tool call",
|
||||||
|
s: "hello <tool_",
|
||||||
|
d: "<tool_call>",
|
||||||
|
want: 6,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := suffixOverlap(tt.s, tt.d)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractToolArgs(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
template string
|
template string
|
||||||
@ -173,7 +349,7 @@ func TestTextAfterToolCalls(t *testing.T) {
|
|||||||
t.Fatalf("failed to parse template: %v", err)
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
got, ok := extractToolCallsTemplate(tmpl)
|
got, ok := extractToolCallsFormat(tmpl)
|
||||||
if got != tt.want {
|
if got != tt.want {
|
||||||
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
|
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
|
||||||
}
|
}
|
||||||
@ -183,3 +359,106 @@ func TestTextAfterToolCalls(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCollect(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
obj any
|
||||||
|
want []map[string]any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple map",
|
||||||
|
obj: map[string]any{
|
||||||
|
"key": "value",
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"key": "value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested map",
|
||||||
|
obj: map[string]any{
|
||||||
|
"outer": map[string]any{
|
||||||
|
"inner": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"outer": map[string]any{"inner": "value"}},
|
||||||
|
{"inner": "value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array of maps",
|
||||||
|
obj: []any{
|
||||||
|
map[string]any{"key1": "val1"},
|
||||||
|
map[string]any{"key2": "val2"},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"key1": "val1"},
|
||||||
|
{"key2": "val2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deeply nested",
|
||||||
|
obj: map[string]any{
|
||||||
|
"l1": map[string]any{
|
||||||
|
"l2": map[string]any{
|
||||||
|
"l3": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []map[string]any{
|
||||||
|
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
|
||||||
|
{"l2": map[string]any{"l3": "value"}},
|
||||||
|
{"l3": "value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-map value",
|
||||||
|
obj: "string",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := collect(tt.obj)
|
||||||
|
if len(got) != len(tt.want) {
|
||||||
|
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare each map in the result
|
||||||
|
for i := range tt.want {
|
||||||
|
if !mapsEqual(got[i], tt.want[i]) {
|
||||||
|
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapsEqual compares two maps for deep equality
|
||||||
|
func mapsEqual(m1, m2 map[string]any) bool {
|
||||||
|
if len(m1) != len(m2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k, v1 := range m1 {
|
||||||
|
v2, ok := m2[k]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch val1 := v1.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
val2, ok := v2.(map[string]any)
|
||||||
|
if !ok || !mapsEqual(val1, val2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if v1 != v2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user