add tests, organize, comments
This commit is contained in:
parent
bc83789be9
commit
8ed95a4e96
@ -1529,9 +1529,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
if err == nil {
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
fmt.Println("sending content in response", content)
|
||||
slog.Debug("tools: setting content to", "content", content)
|
||||
} else if len(toolCalls) > 0 {
|
||||
fmt.Println("sending tool calls in response", toolCalls)
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else {
|
||||
|
280
tools/tools.go
280
tools/tools.go
@ -1,9 +1,7 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
@ -16,136 +14,56 @@ import (
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
// TODO: simplify if possible
|
||||
type Parser struct {
|
||||
greedy bool
|
||||
greedyParse bool
|
||||
prefixFound bool
|
||||
partialPrefix bool
|
||||
prefixPartial bool
|
||||
tmpl *gotmpl.Template
|
||||
sb *strings.Builder
|
||||
prefix string
|
||||
index int
|
||||
name string
|
||||
arguments string
|
||||
Done bool
|
||||
}
|
||||
|
||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
|
||||
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||
fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "@@name@@",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"@@argument@@": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
fmt.Printf("failed to execute template: error=%v\n", err)
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// this can be either a map or an array
|
||||
var temp any
|
||||
err := jsonv2.Unmarshal(b.Bytes(), &temp)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to unmarshal template: error=%v\n", err)
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
var collect func(any) []map[string]any
|
||||
collect = func(obj any) (all []map[string]any) {
|
||||
switch o := obj.(type) {
|
||||
case map[string]any:
|
||||
all = append(all, o)
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
case []any:
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
default:
|
||||
// TODO: err or fallback
|
||||
fmt.Printf("collect encountered unknown type: type=%T\n", obj)
|
||||
return nil
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
var templateObjects []map[string]any
|
||||
switch t := temp.(type) {
|
||||
case map[string]any:
|
||||
templateObjects = []map[string]any{t}
|
||||
case []map[string]any:
|
||||
templateObjects = t
|
||||
// ! fallback?
|
||||
case []any:
|
||||
templateObjects = collect(t)
|
||||
}
|
||||
if len(templateObjects) == 0 {
|
||||
fmt.Println("no template objects found")
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// find the keys that correspond to the name and arguments fields
|
||||
var name, arguments string
|
||||
for k, v := range templateObjects[0] {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
name = k
|
||||
fmt.Printf("found name field: key=%s\n", k)
|
||||
case map[string]any:
|
||||
arguments = k
|
||||
fmt.Printf("found arguments field: key=%s\n", k)
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" || arguments == "" {
|
||||
fmt.Printf("missing required fields: name_found=%v arguments_found=%v\n", name != "", arguments != "")
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// TODO: there is probably some underlying repeat work here to avoid
|
||||
// This incrementally decodes the JSON string and returns the first parsedobject
|
||||
// It first tries to incrementally decode the JSON to handle partial inputs.
|
||||
// Returns:
|
||||
// - []api.ToolCall: The parsed tool calls if successful
|
||||
// - bool: True if JSON is incomplete and needs more input
|
||||
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) {
|
||||
// First try incremental decoding to handle partial JSON
|
||||
dec := jsontext.NewDecoder(strings.NewReader(s))
|
||||
if got, err := dec.ReadValue(); err == nil {
|
||||
s = got.String()
|
||||
fmt.Printf("decoded JSON value: value=%s\n", s)
|
||||
}
|
||||
|
||||
var responseObjects any
|
||||
err = jsonv2.Unmarshal([]byte(s), &responseObjects)
|
||||
// Attempt full unmarshal of the JSON
|
||||
var resp any
|
||||
err := jsonv2.Unmarshal([]byte(s), &resp)
|
||||
if err != nil {
|
||||
// Handle incomplete JSON cases
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
|
||||
fmt.Println("incomplete JSON detected")
|
||||
return nil, true, false
|
||||
} else {
|
||||
fmt.Printf("failed to unmarshal response: error=%v\n", err)
|
||||
return nil, false, false
|
||||
slog.Debug("incomplete JSON detected", "input", s)
|
||||
return nil, true
|
||||
}
|
||||
slog.Debug("failed to unmarshal response", "error", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Collect all nested objects that could contain tool calls
|
||||
var objs []map[string]any
|
||||
objs = append(objs, collect(responseObjects)...)
|
||||
objs = append(objs, collect(resp)...)
|
||||
if len(objs) == 0 {
|
||||
return nil, false, false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fmt.Printf("collected objects: count=%d\n", len(objs))
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
a, aok := kv[arguments].(map[string]any)
|
||||
n, nok := kv[p.name].(string)
|
||||
a, aok := kv[p.arguments].(map[string]any)
|
||||
if nok && aok {
|
||||
fmt.Printf("found valid tool call: name=%s\n", n)
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: n,
|
||||
@ -155,130 +73,170 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("parsed tool calls: count=%d\n", len(toolCalls))
|
||||
return toolCalls, false, true
|
||||
// Valid JSON, no tool calls found
|
||||
if len(toolCalls) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// prefix stripped string if any, prefix found, and if we should accumulate
|
||||
func (p *Parser) checkPrefix(s string) (string, bool, bool) {
|
||||
return toolCalls, false
|
||||
}
|
||||
|
||||
// checkPrefix processes a string to find and handle a prefix pattern.
|
||||
//
|
||||
// Returns:
|
||||
// - The processed string with prefix removed if found
|
||||
// - Whether the prefix was found at the start of the string
|
||||
// - Whether to continue parsing
|
||||
func (p *Parser) checkPrefix(s string) (string, bool, bool) {
|
||||
// Keep original for overlap checks
|
||||
original := s
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "", false, true
|
||||
}
|
||||
|
||||
// If no prefix defined, just return trimmed string
|
||||
if p.prefix == "" {
|
||||
return s, false, true
|
||||
}
|
||||
original := s
|
||||
s = strings.TrimSpace(s)
|
||||
s, hasPrefix := strings.CutPrefix(s, p.prefix)
|
||||
if hasPrefix {
|
||||
// partial tool possibly - accumulate
|
||||
return s, true, true
|
||||
} else if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
|
||||
// p.state = PartialPrefix
|
||||
p.partialPrefix = true
|
||||
return original[0 : len(original)-overlap], false, false
|
||||
} else if idx := strings.Index(original, p.prefix); idx != -1 {
|
||||
// Found prefix in middle of string, keep only content before prefix
|
||||
// accounts for spaces in prefix or suffix to avoid breaking cache
|
||||
p.partialPrefix = true
|
||||
p.sb.Reset()
|
||||
|
||||
// Check for prefix at start of string
|
||||
if processedStr, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
|
||||
// Found prefix at start - accumulate for potential tool
|
||||
return processedStr, true, true
|
||||
}
|
||||
|
||||
// Check if prefix overlaps end of string
|
||||
if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
|
||||
p.prefixPartial = true
|
||||
// Return everything except overlapping portion
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(original[len(original)-overlap:])
|
||||
return original[0 : len(original)-overlap], false, false
|
||||
}
|
||||
|
||||
// Check if prefix appears in middle of string
|
||||
if idx := strings.Index(original, p.prefix); idx != -1 {
|
||||
p.prefixPartial = true
|
||||
// Save remainder starting at prefix for next pass
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(strings.TrimSpace(original[idx:]))
|
||||
// Return everything before prefix
|
||||
return original[:idx], false, false
|
||||
}
|
||||
|
||||
p.partialPrefix = false
|
||||
// No prefix found
|
||||
p.prefixPartial = false
|
||||
return s, false, true
|
||||
}
|
||||
|
||||
// Add processes a string input to parse tool calls and content.
|
||||
// It handles prefix detection and JSON parsing to extract tool calls.
|
||||
//
|
||||
// Returns:
|
||||
// - tools: Any parsed tool calls
|
||||
// - content: Non-tool call content
|
||||
// - err: Error if parsing failed
|
||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) {
|
||||
slog.Debug("adding tool calls", "input", s)
|
||||
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
|
||||
if len(s) == 0 {
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
s, prefixFound, cont := p.checkPrefix(s)
|
||||
|
||||
if !cont {
|
||||
// Check for prefix pattern in input
|
||||
s, prefixFound, shouldContinue := p.checkPrefix(s)
|
||||
if !shouldContinue {
|
||||
if s != "" {
|
||||
// send only the content back, prefix exists
|
||||
// Return content before prefix
|
||||
return nil, s, nil
|
||||
}
|
||||
// accumulate case
|
||||
// Need more input to complete prefix
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
// circuit breaker
|
||||
// Update prefix found state
|
||||
if prefixFound {
|
||||
p.prefixFound = true
|
||||
}
|
||||
|
||||
// for cases with a prefix in template
|
||||
if p.prefix != "" && !p.greedy && !p.prefixFound {
|
||||
// send tokens down
|
||||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||
if !p.greedyParse && !p.prefixFound {
|
||||
p.sb.Reset()
|
||||
return nil, "", errors.New("prefix not found")
|
||||
}
|
||||
// we have a prefix or are in json mode
|
||||
tcs, partial, ok := p.parseJSONToolCalls(s)
|
||||
if partial {
|
||||
// accumulate case
|
||||
|
||||
toolCalls, isPartial := p.parseJSONToolCalls(s)
|
||||
if isPartial {
|
||||
// Need more input to complete JSON
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
p.greedy = false
|
||||
if !ok {
|
||||
// will not be a partial at this point
|
||||
// Do not try greedy parsing if partial JSON not found
|
||||
p.greedyParse = false
|
||||
|
||||
// Handle invalid tool call format
|
||||
if len(toolCalls) == 0 {
|
||||
p.sb.Reset()
|
||||
// send tokens
|
||||
if p.prefix == "" {
|
||||
p.Done = true
|
||||
}
|
||||
if p.prefixFound {
|
||||
// drop tokens instead - sb is reset, no tokens sent to user
|
||||
// Drop tokens since prefix was found
|
||||
return nil, "", nil
|
||||
}
|
||||
return nil, "", errors.New("failed to parse tool calls")
|
||||
return nil, s, nil
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
for _, tc := range toolCalls {
|
||||
tc.Function.Index = p.index
|
||||
p.index++
|
||||
}
|
||||
|
||||
// Mark as done if no prefix needed
|
||||
if p.prefix == "" {
|
||||
p.Done = true
|
||||
}
|
||||
|
||||
p.sb.Reset()
|
||||
return tcs, "", nil
|
||||
return toolCalls, "", nil
|
||||
}
|
||||
|
||||
// NewParser creates a new tool call parser from a template. It extracts the tool call format,
|
||||
// prefix, and field names from the template to use for parsing tool calls from model output.
|
||||
//
|
||||
// Returns an error if the template does not contain valid tool call formatting.
|
||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||
parsedTemplate, err := template.Parse(templateToProcess.Root.String())
|
||||
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsedTemplate == nil {
|
||||
if parsed == nil {
|
||||
return nil, errors.New("failed to parse template")
|
||||
}
|
||||
|
||||
toolCallTemplate, hasToolCalls := toolTemplate(parsedTemplate)
|
||||
if !hasToolCalls {
|
||||
return nil, errors.New("failed to find tool template")
|
||||
tt, tc := toolTemplate(parsed)
|
||||
if !tc {
|
||||
return nil, errors.New("failed to find tool calls in template")
|
||||
}
|
||||
if toolCallTemplate == nil {
|
||||
if tt == nil {
|
||||
return nil, errors.New("failed to find tool template")
|
||||
}
|
||||
|
||||
toolPrefix, _ := ToolPrefix(templateToProcess)
|
||||
toolPrefix = strings.TrimSpace(toolPrefix)
|
||||
tp := toolPrefix(templateToProcess)
|
||||
tp = strings.TrimSpace(tp)
|
||||
|
||||
name, arguments, err := extractToolArgs(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Printf("creating new tool parser: prefix=%s\n", toolPrefix)
|
||||
return &Parser{
|
||||
tmpl: toolCallTemplate,
|
||||
tmpl: tt,
|
||||
sb: &strings.Builder{},
|
||||
prefix: toolPrefix,
|
||||
greedy: true,
|
||||
prefix: tp,
|
||||
greedyParse: true,
|
||||
name: name,
|
||||
arguments: arguments,
|
||||
}, nil
|
||||
}
|
||||
|
@ -6,11 +6,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
gotmpl "text/template"
|
||||
"text/template/parse"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
@ -206,6 +203,27 @@ func TestParseToolCalls(t *testing.T) {
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls without prefix and valid tool call",
|
||||
model: "qwen2.5-coder",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls without prefix and invalid tool call",
|
||||
model: "qwen2.5-coder",
|
||||
output: `[{"options": "foo"}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `[{"options": "foo"}]`,
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with prefix and invalid tool call",
|
||||
model: "qwen2.5-coder",
|
||||
output: `<tool_call> [{"options": "foo"}] </tool_call> `,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: ``,
|
||||
},
|
||||
{
|
||||
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
|
||||
model: "qwen3",
|
||||
@ -239,14 +257,14 @@ func TestParseToolCalls(t *testing.T) {
|
||||
model: "qwen3",
|
||||
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think> fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedTokens: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think> fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedTokens: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with malformed tool prefix",
|
||||
@ -332,6 +350,7 @@ func TestParseToolCalls(t *testing.T) {
|
||||
toolCalls, content, err := tp.Add(s)
|
||||
if err == nil {
|
||||
if content != "" {
|
||||
fmt.Printf("content: %q\n", content)
|
||||
gotTokens.WriteString(content)
|
||||
add = false
|
||||
} else if len(toolCalls) > 0 {
|
||||
@ -363,24 +382,101 @@ func TestParseToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func toolTemplateHelper(t *testing.T, tmpl *template.Template) (*gotmpl.Template, bool) {
|
||||
// create a subtree from the node that ranges over .ToolCalls
|
||||
|
||||
tmpl2 := tmpl.Subtree(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
func TestParseJSONToolCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
parser *Parser
|
||||
wantToolCalls []api.ToolCall
|
||||
wantPartial bool
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "valid single tool call",
|
||||
input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`,
|
||||
parser: &Parser{name: "name", arguments: "arguments"},
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test_tool",
|
||||
Arguments: map[string]any{
|
||||
"arg1": "value1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantPartial: false,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete JSON",
|
||||
input: `{"name": "test_tool", "arguments": {"arg1": `,
|
||||
parser: &Parser{name: "name", arguments: "arguments"},
|
||||
wantToolCalls: nil,
|
||||
wantPartial: true,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
input: `not json at all`,
|
||||
parser: &Parser{name: "name", arguments: "arguments"},
|
||||
wantToolCalls: nil,
|
||||
wantPartial: false,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "missing required fields",
|
||||
input: `{"other": "field"}`,
|
||||
parser: &Parser{name: "name", arguments: "arguments"},
|
||||
wantToolCalls: nil,
|
||||
wantPartial: false,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls in array",
|
||||
input: `[
|
||||
{"name": "tool1", "arguments": {"arg1": 1}},
|
||||
{"name": "tool2", "arguments": {"arg2": "value"}}
|
||||
]`,
|
||||
parser: &Parser{name: "name", arguments: "arguments"},
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool1",
|
||||
Arguments: map[string]any{
|
||||
"arg1": float64(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool2",
|
||||
Arguments: map[string]any{
|
||||
"arg2": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantPartial: false,
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
return false
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotCalls, gotPartial := tt.parser.parseJSONToolCalls(tt.input)
|
||||
|
||||
if gotPartial != tt.wantPartial {
|
||||
t.Errorf("parseJSONToolCalls() partial = %v, want %v", gotPartial, tt.wantPartial)
|
||||
}
|
||||
|
||||
if len(gotCalls) != 0 != tt.wantValid {
|
||||
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantValid)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" {
|
||||
t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
if tmpl2.Root != nil {
|
||||
t.Log("tmpl2", tmpl2.Root.String())
|
||||
}
|
||||
|
||||
if tmpl2 == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return tmpl2, true
|
||||
}
|
||||
|
132
tools/utils.go
132
tools/utils.go
@ -1,17 +1,27 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
"text/template/parse"
|
||||
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls"
|
||||
func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
||||
// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
|
||||
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
|
||||
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
|
||||
//
|
||||
// Returns:
|
||||
// - string: The extracted text following the first ".ToolCalls" condition found
|
||||
// - bool: Whether a ".ToolCalls" condition was found in the template
|
||||
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
|
||||
if tmpl == nil || tmpl.Tree == nil {
|
||||
slog.Debug("TextAfterToolCalls: template or tree is nil")
|
||||
return "", false
|
||||
@ -29,7 +39,7 @@ func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
||||
|
||||
switch n := node.(type) {
|
||||
case *parse.IfNode:
|
||||
if nodeContainsToolCalls(n) {
|
||||
if isToolCallsNode(n) {
|
||||
// Collect immediate TextNode(s) at start of IfNode's list
|
||||
var sb strings.Builder
|
||||
for _, innerNode := range n.List.Nodes {
|
||||
@ -76,8 +86,8 @@ func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
||||
return result, found
|
||||
}
|
||||
|
||||
// Helper to detect if a node's condition includes ".ToolCalls"
|
||||
func nodeContainsToolCalls(n *parse.IfNode) bool {
|
||||
// isToolCallsNode detects if a node's condition includes ".ToolCalls"
|
||||
func isToolCallsNode(n *parse.IfNode) bool {
|
||||
for _, cmd := range n.Pipe.Cmds {
|
||||
for _, arg := range cmd.Args {
|
||||
if field, ok := arg.(*parse.FieldNode); ok {
|
||||
@ -90,16 +100,17 @@ func nodeContainsToolCalls(n *parse.IfNode) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ToolPrefix returns the prefix for the tool call if it exists
|
||||
// TODO(parthsareen): get full prefix from the template instead of just the first token
|
||||
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
|
||||
tokenText, ok := extractToolCallsTemplate(tmpl)
|
||||
|
||||
// toolPrefix returns the prefix for the tool call if it exists from a template
|
||||
func toolPrefix(tmpl *gotmpl.Template) string {
|
||||
tokenText, ok := extractToolCallsFormat(tmpl)
|
||||
if !ok {
|
||||
return "", false
|
||||
return ""
|
||||
}
|
||||
tokenText = strings.TrimSpace(tokenText)
|
||||
if tokenText == "" {
|
||||
return "", false
|
||||
return ""
|
||||
}
|
||||
first := strings.Fields(tokenText)[0]
|
||||
|
||||
@ -116,19 +127,23 @@ func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
|
||||
}
|
||||
if start != -1 && end != -1 {
|
||||
// return the token including the [ or < and the ] or >
|
||||
return tokenText[start : end+1], true
|
||||
return tokenText[start : end+1]
|
||||
} else if start != -1 {
|
||||
// get until the [ or < - in the case tag was not closed
|
||||
return tokenText[:start], true
|
||||
return tokenText[:start]
|
||||
} else if end != -1 {
|
||||
// get after the ] or > - in the case tag was not opened
|
||||
return tokenText[end+1:], true
|
||||
return tokenText[end+1:]
|
||||
}
|
||||
return first, true
|
||||
return first
|
||||
}
|
||||
|
||||
// toolTemplate creates a subtree from the node that ranges over .ToolCalls
|
||||
//
|
||||
// Returns:
|
||||
// - *gotmpl.Template: The subtree containing the .ToolCalls range
|
||||
// - bool: Whether a .ToolCalls range was found in the template
|
||||
func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
|
||||
// create a subtree from the node that ranges over .ToolCalls
|
||||
tmpl := t.Subtree(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
@ -144,6 +159,10 @@ func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
|
||||
return tmpl, true
|
||||
}
|
||||
|
||||
// suffixOverlap returns the length of the longest suffix overlap between two strings
|
||||
//
|
||||
// Returns:
|
||||
// - int: The length of the longest suffix overlap
|
||||
func suffixOverlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
@ -153,3 +172,86 @@ func suffixOverlap(s, delim string) int {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// extractToolArgs executes a template with a known tool call format to extract the name and arguments
|
||||
//
|
||||
// Returns:
|
||||
// - string: The name of the tool call
|
||||
// - string: The arguments of the tool call
|
||||
// - error: Error if parsing failed
|
||||
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "@@name@@",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"@@argument@@": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var obj any
|
||||
err = jsonv2.Unmarshal(b.Bytes(), &obj)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var objs []map[string]any
|
||||
switch v := obj.(type) {
|
||||
case map[string]any:
|
||||
objs = []map[string]any{v}
|
||||
case []map[string]any:
|
||||
objs = v
|
||||
case []any:
|
||||
objs = collect(v)
|
||||
}
|
||||
if len(objs) == 0 {
|
||||
return "", "", errors.New("no template objects found")
|
||||
}
|
||||
|
||||
// find the keys that correspond to the name and arguments fields
|
||||
for k, v := range objs[0] {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
name = k
|
||||
case map[string]any:
|
||||
arguments = k
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" || arguments == "" {
|
||||
slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
|
||||
return "", "", errors.New("missing required fields in tool call template")
|
||||
}
|
||||
|
||||
return name, arguments, nil
|
||||
}
|
||||
|
||||
// collect recursively traverses an object to collect all nested maps
|
||||
//
|
||||
// Returns:
|
||||
// - []map[string]any: A slice of all nested maps found in the object
|
||||
func collect(obj any) []map[string]any {
|
||||
var all []map[string]any
|
||||
switch o := obj.(type) {
|
||||
case map[string]any:
|
||||
all = append(all, o)
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
case []any:
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
@ -3,74 +3,133 @@ package tools
|
||||
import (
|
||||
"testing"
|
||||
gotmpl "text/template"
|
||||
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func TestExtractToolCallsFormat(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "nil template",
|
||||
template: "",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "basic tool call with text",
|
||||
template: "{{if .ToolCalls}}Hello world{{end}}",
|
||||
want: "Hello world",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "tool call with json format",
|
||||
template: "{{if .ToolCalls}}```json\n{{end}}",
|
||||
want: "```json\n",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "tool call in range",
|
||||
template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple text nodes",
|
||||
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
|
||||
want: "First text",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "nested if without tool calls",
|
||||
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tc.template)
|
||||
if err != nil && tc.template != "" {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
got, found := extractToolCallsFormat(tmpl)
|
||||
if got != tc.want {
|
||||
t.Errorf("got text %q, want %q", got, tc.want)
|
||||
}
|
||||
if found != tc.found {
|
||||
t.Errorf("got found %v, want %v", found, tc.found)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolPrefix(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "basic tool call with action prefix",
|
||||
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||
want: "Action:",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete functools bracket",
|
||||
template: "{{if .ToolCalls}}functools[{{end}}",
|
||||
want: "functools",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "tool call with angle brackets",
|
||||
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
||||
want: "<tool_call>",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "multiple tool call formats",
|
||||
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
||||
want: "[tool_call]",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "single angle bracket tool call",
|
||||
template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
||||
want: "<tool_call>",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete angle bracket after tool call",
|
||||
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
||||
want: "[tool_call]",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "angle bracket prefix with tool call",
|
||||
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
||||
want: "<tool_call>",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with incomplete bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||
want: "[TOOL_CALL]",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with adjacent bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||
want: "[TOOL_CALL]",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "tool call with pipe delimiters",
|
||||
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
||||
want: "<|tool_call|>",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "tool with no prefix",
|
||||
template: "{{if .ToolCalls}}{{end}}",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
@ -80,18 +139,135 @@ func TestToolPrefix(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
got, ok := ToolPrefix(tmpl)
|
||||
got := toolPrefix(tmpl)
|
||||
if got != tt.want {
|
||||
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
||||
}
|
||||
if ok != tt.ok {
|
||||
t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextAfterToolCalls(t *testing.T) {
|
||||
func TestToolTemplate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "basic tool call range",
|
||||
template: "{{range .ToolCalls}}test{{end}}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no tool calls",
|
||||
template: "{{range .Other}}test{{end}}",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nested tool calls",
|
||||
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "tool calls in if statement",
|
||||
template: "{{if .ToolCalls}}test{{end}}",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := template.Parse(tmpl.Root.String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
_, got := toolTemplate(parsed)
|
||||
if got != tt.want {
|
||||
t.Errorf("toolTemplate() = %v; want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuffixOverlap(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
s string
|
||||
d string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "no overlap",
|
||||
s: "hello world",
|
||||
d: "",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "full overlap",
|
||||
s: "<tool_call>",
|
||||
d: "<tool_call>",
|
||||
want: 11,
|
||||
},
|
||||
{
|
||||
name: "partial overlap",
|
||||
s: "text <tool_call>",
|
||||
d: "<tool_call>",
|
||||
want: 11,
|
||||
},
|
||||
{
|
||||
name: "delimiter longer than string",
|
||||
s: "<tool>",
|
||||
d: "<tool_call>",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
s: "",
|
||||
d: "<tool_call>",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "empty delimiter",
|
||||
s: "<tool_call>",
|
||||
d: "",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "single char overlap",
|
||||
s: "test<",
|
||||
d: "<tool_call>",
|
||||
want: 1,
|
||||
},
|
||||
{
|
||||
name: "partial tool call",
|
||||
s: "hello <tool_",
|
||||
d: "<tool_call>",
|
||||
want: 6,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := suffixOverlap(tt.s, tt.d)
|
||||
if got != tt.want {
|
||||
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolArgs(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
@ -173,7 +349,7 @@ func TestTextAfterToolCalls(t *testing.T) {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
got, ok := extractToolCallsTemplate(tmpl)
|
||||
got, ok := extractToolCallsFormat(tmpl)
|
||||
if got != tt.want {
|
||||
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
|
||||
}
|
||||
@ -183,3 +359,106 @@ func TestTextAfterToolCalls(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollect(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
obj any
|
||||
want []map[string]any
|
||||
}{
|
||||
{
|
||||
name: "simple map",
|
||||
obj: map[string]any{
|
||||
"key": "value",
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"key": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested map",
|
||||
obj: map[string]any{
|
||||
"outer": map[string]any{
|
||||
"inner": "value",
|
||||
},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"outer": map[string]any{"inner": "value"}},
|
||||
{"inner": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "array of maps",
|
||||
obj: []any{
|
||||
map[string]any{"key1": "val1"},
|
||||
map[string]any{"key2": "val2"},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"key1": "val1"},
|
||||
{"key2": "val2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
obj: map[string]any{
|
||||
"l1": map[string]any{
|
||||
"l2": map[string]any{
|
||||
"l3": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
|
||||
{"l2": map[string]any{"l3": "value"}},
|
||||
{"l3": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-map value",
|
||||
obj: "string",
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := collect(tt.obj)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
|
||||
return
|
||||
}
|
||||
|
||||
// Compare each map in the result
|
||||
for i := range tt.want {
|
||||
if !mapsEqual(got[i], tt.want[i]) {
|
||||
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mapsEqual compares two maps for deep equality
|
||||
func mapsEqual(m1, m2 map[string]any) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for k, v1 := range m1 {
|
||||
v2, ok := m2[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch val1 := v1.(type) {
|
||||
case map[string]any:
|
||||
val2, ok := v2.(map[string]any)
|
||||
if !ok || !mapsEqual(val1, val2) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if v1 != v2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user