tools package and utils
This commit is contained in:
parent
4059b8db01
commit
bc83789be9
@ -10,9 +10,6 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
|
||||||
gotmpl "text/template"
|
|
||||||
"text/template/parse"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
@ -129,19 +126,19 @@ func detectContentType(r io.Reader) (string, error) {
|
|||||||
return "unknown", nil
|
return "unknown", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToolTemplate(m *Model) (*gotmpl.Template, bool) {
|
// func ToolTemplate(m *Model) (*gotmpl.Template, bool) {
|
||||||
// create a subtree from the node that ranges over .ToolCalls
|
// // create a subtree from the node that ranges over .ToolCalls
|
||||||
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
// tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
||||||
if t, ok := n.(*parse.RangeNode); ok {
|
// if t, ok := n.(*parse.RangeNode); ok {
|
||||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
// return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||||
}
|
// }
|
||||||
|
|
||||||
return false
|
// return false
|
||||||
})
|
// })
|
||||||
|
|
||||||
if tmpl == nil {
|
// if tmpl == nil {
|
||||||
return nil, false
|
// return nil, false
|
||||||
}
|
// }
|
||||||
|
|
||||||
return tmpl, true
|
// return tmpl, true
|
||||||
}
|
// }
|
||||||
|
@ -1,185 +1,185 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
// import (
|
||||||
"testing"
|
// "testing"
|
||||||
gotmpl "text/template"
|
// gotmpl "text/template"
|
||||||
)
|
// )
|
||||||
|
|
||||||
func TestToolToken(t *testing.T) {
|
// func TestToolToken(t *testing.T) {
|
||||||
cases := []struct {
|
// cases := []struct {
|
||||||
name string
|
// name string
|
||||||
template string
|
// template string
|
||||||
want string
|
// want string
|
||||||
ok bool
|
// ok bool
|
||||||
}{
|
// }{
|
||||||
{
|
// {
|
||||||
name: "basic tool call with action prefix",
|
// name: "basic tool call with action prefix",
|
||||||
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
// template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||||
want: "Action:",
|
// want: "Action:",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "incomplete functools bracket",
|
// name: "incomplete functools bracket",
|
||||||
template: "{{if .ToolCalls}}functools[{{end}}",
|
// template: "{{if .ToolCalls}}functools[{{end}}",
|
||||||
want: "functools",
|
// want: "functools",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "tool call with angle brackets",
|
// name: "tool call with angle brackets",
|
||||||
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
// template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
||||||
want: "<tool_call>",
|
// want: "<tool_call>",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "multiple tool call formats",
|
// name: "multiple tool call formats",
|
||||||
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
// template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
||||||
want: "[tool_call]",
|
// want: "[tool_call]",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "single angle bracket tool call",
|
// name: "single angle bracket tool call",
|
||||||
template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
// template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
||||||
want: "<tool_call>",
|
// want: "<tool_call>",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "incomplete angle bracket after tool call",
|
// name: "incomplete angle bracket after tool call",
|
||||||
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
// template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
||||||
want: "[tool_call]",
|
// want: "[tool_call]",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "angle bracket prefix with tool call",
|
// name: "angle bracket prefix with tool call",
|
||||||
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
// template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
||||||
want: "<tool_call>",
|
// want: "<tool_call>",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "uppercase tool call with incomplete bracket",
|
// name: "uppercase tool call with incomplete bracket",
|
||||||
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
// template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||||
want: "[TOOL_CALL]",
|
// want: "[TOOL_CALL]",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "uppercase tool call with adjacent bracket",
|
// name: "uppercase tool call with adjacent bracket",
|
||||||
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
// template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||||
want: "[TOOL_CALL]",
|
// want: "[TOOL_CALL]",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "tool call with pipe delimiters",
|
// name: "tool call with pipe delimiters",
|
||||||
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
// template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
||||||
want: "<|tool_call|>",
|
// want: "<|tool_call|>",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
}
|
// }
|
||||||
|
|
||||||
for _, tt := range cases {
|
// for _, tt := range cases {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
// t.Run(tt.name, func(t *testing.T) {
|
||||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
// tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Fatalf("failed to parse template: %v", err)
|
// t.Fatalf("failed to parse template: %v", err)
|
||||||
}
|
// }
|
||||||
got, ok := ToolPrefix(tmpl)
|
// got, ok := ToolPrefix(tmpl)
|
||||||
if got != tt.want {
|
// if got != tt.want {
|
||||||
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
// t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
||||||
}
|
// }
|
||||||
if ok != tt.ok {
|
// if ok != tt.ok {
|
||||||
t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok)
|
// t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok)
|
||||||
}
|
// }
|
||||||
})
|
// })
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
func TestTextAfterToolCalls(t *testing.T) {
|
// func TestTextAfterToolCalls(t *testing.T) {
|
||||||
cases := []struct {
|
// cases := []struct {
|
||||||
name string
|
// name string
|
||||||
template string
|
// template string
|
||||||
want string
|
// want string
|
||||||
ok bool
|
// ok bool
|
||||||
}{
|
// }{
|
||||||
{
|
// {
|
||||||
name: "basic tool call with text after",
|
// name: "basic tool call with text after",
|
||||||
template: `{{if .ToolCalls}}tool response{{end}}`,
|
// template: `{{if .ToolCalls}}tool response{{end}}`,
|
||||||
want: "tool response",
|
// want: "tool response",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "tool call with mixed content after",
|
// name: "tool call with mixed content after",
|
||||||
template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`,
|
// template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`,
|
||||||
want: "<tool_call>",
|
// want: "<tool_call>",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "tool call with no text after",
|
// name: "tool call with no text after",
|
||||||
template: `{{if .ToolCalls}}{{.Something}}{{end}}`,
|
// template: `{{if .ToolCalls}}{{.Something}}{{end}}`,
|
||||||
want: "",
|
// want: "",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "nested tool call",
|
// name: "nested tool call",
|
||||||
template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`,
|
// template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`,
|
||||||
want: "[TOOL_CALL]",
|
// want: "[TOOL_CALL]",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "no tool calls",
|
// name: "no tool calls",
|
||||||
template: `{{if .Something}}no tools here{{end}}`,
|
// template: `{{if .Something}}no tools here{{end}}`,
|
||||||
want: "",
|
// want: "",
|
||||||
ok: false,
|
// ok: false,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "empty template",
|
// name: "empty template",
|
||||||
template: ``,
|
// template: ``,
|
||||||
want: "",
|
// want: "",
|
||||||
ok: false,
|
// ok: false,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "multiple tool calls sections",
|
// name: "multiple tool calls sections",
|
||||||
template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`,
|
// template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`,
|
||||||
want: "first",
|
// want: "first",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "range over tool calls",
|
// name: "range over tool calls",
|
||||||
template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`,
|
// template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`,
|
||||||
want: "",
|
// want: "",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "tool calls with pipe delimiters",
|
// name: "tool calls with pipe delimiters",
|
||||||
template: `{{if .ToolCalls}}<|tool|>{{end}}`,
|
// template: `{{if .ToolCalls}}<|tool|>{{end}}`,
|
||||||
want: "<|tool|>",
|
// want: "<|tool|>",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "tool calls with nested template",
|
// name: "tool calls with nested template",
|
||||||
template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`,
|
// template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`,
|
||||||
want: "",
|
// want: "",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
{
|
// {
|
||||||
name: "tool calls with whitespace variations",
|
// name: "tool calls with whitespace variations",
|
||||||
template: `{{if .ToolCalls}} tool {{end}}`,
|
// template: `{{if .ToolCalls}} tool {{end}}`,
|
||||||
want: " tool ",
|
// want: " tool ",
|
||||||
ok: true,
|
// ok: true,
|
||||||
},
|
// },
|
||||||
}
|
// }
|
||||||
|
|
||||||
for _, tt := range cases {
|
// for _, tt := range cases {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
// t.Run(tt.name, func(t *testing.T) {
|
||||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
// tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Fatalf("failed to parse template: %v", err)
|
// t.Fatalf("failed to parse template: %v", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
got, ok := extractToolCallsTemplate(tmpl)
|
// got, ok := extractToolCallsTemplate(tmpl)
|
||||||
if got != tt.want {
|
// if got != tt.want {
|
||||||
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
|
// t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
|
||||||
}
|
// }
|
||||||
if ok != tt.ok {
|
// if ok != tt.ok {
|
||||||
t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok)
|
// t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok)
|
||||||
}
|
// }
|
||||||
})
|
// })
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
@ -1483,19 +1483,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
||||||
|
|
||||||
|
var toolParser *tools.Parser
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
toolParser, err = tools.NewParser(m.Template.Template)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to create tool parser", "error", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
// ! personally not a fan of this pattern
|
|
||||||
toolTemplate, ok := ToolTemplate(m)
|
|
||||||
if !ok {
|
|
||||||
slog.Error("tool template not found", "model", m.Name)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var toolParser *tools.Parser
|
|
||||||
if len(req.Tools) > 0 {
|
|
||||||
toolParser = tools.NewParser(m.Template.Template, toolTemplate)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
@ -1523,30 +1525,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(req.Tools) > 0 && !toolParser.Done {
|
if len(req.Tools) > 0 && !toolParser.Done {
|
||||||
toolCalls, leftover := toolParser.ParseToolCalls(r.Content)
|
toolCalls, content, err := toolParser.Add(r.Content)
|
||||||
// * This can be abstracted again to a .handleState(tp.state)
|
if err == nil {
|
||||||
// * However, we'd need a flag to indicate whether to send the response or not
|
if len(content) > 0 {
|
||||||
// * happy to take whatever is more idiomatic
|
res.Message.Content = content
|
||||||
switch toolParser.ParserState {
|
fmt.Println("sending content in response", content)
|
||||||
case tools.ToolCallAccumulate:
|
} else if len(toolCalls) > 0 {
|
||||||
// tokens are accumulated in the tool parser
|
fmt.Println("sending tool calls in response", toolCalls)
|
||||||
return
|
res.Message.ToolCalls = toolCalls
|
||||||
case tools.ToolCallSendTokens:
|
res.Message.Content = ""
|
||||||
// tokens are sent back in the response
|
} else {
|
||||||
case tools.ToolCallSendPartial:
|
return
|
||||||
// tokens not needed for parsing are sent back in the response
|
|
||||||
if len(leftover) > 0 {
|
|
||||||
res.Message.Content = leftover
|
|
||||||
}
|
}
|
||||||
// ! state is needed as we need to not match on the other states
|
|
||||||
case tools.ToolCallFound:
|
|
||||||
res.Message.ToolCalls = toolCalls
|
|
||||||
res.Message.Content = ""
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("sending response", res.Message.Content)
|
|
||||||
// * this is where we'd need the flag if we have a .handleState(tp.state)
|
|
||||||
ch <- res
|
ch <- res
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
441
tools/tools.go
441
tools/tools.go
@ -6,91 +6,28 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
gotmpl "text/template"
|
gotmpl "text/template"
|
||||||
"text/template/parse"
|
|
||||||
|
|
||||||
jsonv2 "github.com/go-json-experiment/json"
|
jsonv2 "github.com/go-json-experiment/json"
|
||||||
jsontext "github.com/go-json-experiment/json/jsontext"
|
jsontext "github.com/go-json-experiment/json/jsontext"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
type State int
|
|
||||||
|
|
||||||
// TODO: potentially coalesce states
|
|
||||||
const (
|
|
||||||
SendTokens State = iota
|
|
||||||
GreedyToolWithPrefix
|
|
||||||
GreedyToolNoPrefix
|
|
||||||
ForceTools
|
|
||||||
ToolSuffix
|
|
||||||
ContainsPrefix
|
|
||||||
PartialPrefix
|
|
||||||
NotPartialPrefix
|
|
||||||
Done
|
|
||||||
)
|
|
||||||
|
|
||||||
type ExternalState int
|
|
||||||
|
|
||||||
const (
|
|
||||||
ToolCallFound ExternalState = iota
|
|
||||||
ToolCallSendPartial
|
|
||||||
ToolCallAccumulate
|
|
||||||
ToolCallSendTokens
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s ExternalState) String() string {
|
|
||||||
switch s {
|
|
||||||
case ToolCallFound:
|
|
||||||
return "ToolCallFound"
|
|
||||||
case ToolCallSendPartial:
|
|
||||||
return "ToolCallSendPartial"
|
|
||||||
case ToolCallAccumulate:
|
|
||||||
return "ToolCallAccumulate"
|
|
||||||
case ToolCallSendTokens:
|
|
||||||
return "ToolCallSendTokens"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("Unknown ExternalState (%d)", s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s State) String() string {
|
|
||||||
switch s {
|
|
||||||
case SendTokens:
|
|
||||||
return "SendTokens"
|
|
||||||
case GreedyToolWithPrefix:
|
|
||||||
return "GreedyToolWithPrefix"
|
|
||||||
case GreedyToolNoPrefix:
|
|
||||||
return "GreedyToolNoPrefix"
|
|
||||||
case ForceTools:
|
|
||||||
return "ForceTools"
|
|
||||||
case ToolSuffix:
|
|
||||||
return "ToolSuffix"
|
|
||||||
case PartialPrefix:
|
|
||||||
return "PossiblePrefix"
|
|
||||||
case Done:
|
|
||||||
return "Done"
|
|
||||||
case ContainsPrefix:
|
|
||||||
return "PartialPrefix"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("Unknown State (%d)", s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: simplify if possible
|
// TODO: simplify if possible
|
||||||
type Parser struct {
|
type Parser struct {
|
||||||
tmpl *gotmpl.Template
|
greedy bool
|
||||||
state State
|
prefixFound bool
|
||||||
sb *strings.Builder
|
partialPrefix bool
|
||||||
toolPrefix string
|
tmpl *gotmpl.Template
|
||||||
toolIndex int
|
sb *strings.Builder
|
||||||
ParserState ExternalState
|
prefix string
|
||||||
Done bool
|
index int
|
||||||
|
Done bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ? move to a separate file
|
|
||||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
|
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
|
||||||
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||||
@ -222,314 +159,126 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
return toolCalls, false, true
|
return toolCalls, false, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: clean up the boundary of internal and external state transitions
|
// prefix stripped string if any, prefix found, and if we should accumulate
|
||||||
func (p *Parser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) {
|
func (p *Parser) checkPrefix(s string) (string, bool, bool) {
|
||||||
fmt.Printf("updating output state: ok=%v partial=%v tool_calls=%d current_state=%s\n", ok, partial, len(tcs), p.state)
|
|
||||||
|
|
||||||
// state transition logic
|
if p.prefix == "" {
|
||||||
switch {
|
return s, false, true
|
||||||
case !ok && !partial && p.state == ForceTools:
|
|
||||||
// force partial tool if we have a prefix
|
|
||||||
// no op and stay in force tools
|
|
||||||
p.sb.Reset()
|
|
||||||
case !ok && !partial:
|
|
||||||
if p.state == GreedyToolNoPrefix {
|
|
||||||
p.state = Done
|
|
||||||
// ? the output parser state is the same even though internal can we not leak the external state?
|
|
||||||
p.Done = true
|
|
||||||
}
|
|
||||||
if p.state == GreedyToolWithPrefix {
|
|
||||||
p.state = SendTokens
|
|
||||||
}
|
|
||||||
if p.state == PartialPrefix {
|
|
||||||
p.state = NotPartialPrefix
|
|
||||||
}
|
|
||||||
case !ok && partial:
|
|
||||||
// acucumulate
|
|
||||||
|
|
||||||
case len(tcs) > 0:
|
|
||||||
// do not parse again in the greedy JSON case as soon as we have a tool call
|
|
||||||
p.sb.Reset()
|
|
||||||
}
|
|
||||||
p.updateExternalState(tcs)
|
|
||||||
fmt.Printf("state updated: new_state=%s parser_state=%s\n", p.state, p.ParserState)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Parser) updateExternalState(tcs []api.ToolCall) {
|
|
||||||
fmt.Printf("updating external state: current_state=%s tool_calls=%d\n", p.state, len(tcs))
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case len(tcs) > 0:
|
|
||||||
// do not parse again in the greedy JSON case as soon as we have a tool call
|
|
||||||
if p.state == GreedyToolWithPrefix {
|
|
||||||
p.state = SendTokens
|
|
||||||
} else if p.state == GreedyToolNoPrefix {
|
|
||||||
p.state = Done
|
|
||||||
p.Done = true
|
|
||||||
}
|
|
||||||
p.ParserState = ToolCallFound
|
|
||||||
case p.state == GreedyToolWithPrefix || p.state == GreedyToolNoPrefix ||
|
|
||||||
p.state == ToolSuffix || p.state == PartialPrefix ||
|
|
||||||
(p.state == ForceTools && len(tcs) == 0):
|
|
||||||
p.ParserState = ToolCallAccumulate
|
|
||||||
case p.state == ContainsPrefix:
|
|
||||||
p.ParserState = ToolCallSendPartial
|
|
||||||
case p.state == SendTokens || p.state == Done:
|
|
||||||
p.ParserState = ToolCallSendTokens
|
|
||||||
case p.state == NotPartialPrefix:
|
|
||||||
p.ParserState = ToolCallSendPartial
|
|
||||||
default:
|
|
||||||
p.ParserState = ToolCallSendTokens
|
|
||||||
p.sb.Reset()
|
|
||||||
p.state = SendTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// string, and if it has a prefix
|
|
||||||
func (p *Parser) checkPrefix(s string) (string, bool) {
|
|
||||||
fmt.Printf("checking prefix: input=%s prefix=%s\n", s, p.toolPrefix)
|
|
||||||
|
|
||||||
if p.toolPrefix == "" {
|
|
||||||
return s, true
|
|
||||||
}
|
}
|
||||||
original := s
|
original := s
|
||||||
s, hasPrefix := strings.CutPrefix(s, p.toolPrefix)
|
s = strings.TrimSpace(s)
|
||||||
|
s, hasPrefix := strings.CutPrefix(s, p.prefix)
|
||||||
if hasPrefix {
|
if hasPrefix {
|
||||||
p.state = ForceTools
|
|
||||||
fmt.Printf("found exact prefix match: remaining=%s\n", s)
|
|
||||||
// partial tool possibly - accumulate
|
// partial tool possibly - accumulate
|
||||||
} else if suffixOverlap(s, p.toolPrefix) > 0 {
|
return s, true, true
|
||||||
p.state = PartialPrefix
|
} else if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
|
||||||
fmt.Printf("found partial prefix: remaining=%s\n", s)
|
// p.state = PartialPrefix
|
||||||
return "", false
|
p.partialPrefix = true
|
||||||
// the case where "token<tool_call>" - send "token" back
|
return original[0 : len(original)-overlap], false, false
|
||||||
|
} else if idx := strings.Index(original, p.prefix); idx != -1 {
|
||||||
|
// Found prefix in middle of string, keep only content before prefix
|
||||||
// accounts for spaces in prefix or suffix to avoid breaking cache
|
// accounts for spaces in prefix or suffix to avoid breaking cache
|
||||||
} else if strings.Contains(original, p.toolPrefix) {
|
p.partialPrefix = true
|
||||||
idx := strings.Index(original, p.toolPrefix)
|
p.sb.Reset()
|
||||||
if idx != -1 {
|
|
||||||
// still keeps the prefix
|
p.sb.WriteString(strings.TrimSpace(original[idx:]))
|
||||||
p.state = ContainsPrefix
|
return original[:idx], false, false
|
||||||
p.sb.Reset()
|
|
||||||
// todo: see if there is a simpler way for this
|
|
||||||
idx2 := strings.Index(s, p.toolPrefix)
|
|
||||||
// buffer now only has the prefix
|
|
||||||
p.sb.WriteString(s[idx2:])
|
|
||||||
fmt.Printf("found prefix in middle: prefix_start=%d content_before=%s\n", idx, original[:idx])
|
|
||||||
return original[:idx], false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, true
|
p.partialPrefix = false
|
||||||
|
return s, false, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: simplify the flow of this function
|
func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) {
|
||||||
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
slog.Debug("adding tool calls", "input", s)
|
||||||
// Returns tool calls, whether parsing is incomplete, and any errors.
|
|
||||||
func (p *Parser) ParseToolCalls(s string) ([]api.ToolCall, string) {
|
|
||||||
fmt.Printf("parsing tool calls: input=%s current_state=%s\n", s, p.state)
|
|
||||||
|
|
||||||
p.sb.WriteString(s)
|
p.sb.WriteString(s)
|
||||||
s = p.sb.String()
|
s = p.sb.String()
|
||||||
|
|
||||||
s = strings.TrimSpace(s)
|
|
||||||
|
|
||||||
if len(s) == 0 {
|
if len(s) == 0 {
|
||||||
p.updateExternalState(nil)
|
return nil, "", nil
|
||||||
return nil, ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s, cont := p.checkPrefix(s)
|
s, prefixFound, cont := p.checkPrefix(s)
|
||||||
|
|
||||||
if !cont {
|
if !cont {
|
||||||
p.updateExternalState(nil)
|
if s != "" {
|
||||||
if p.state == ContainsPrefix {
|
// send only the content back, prefix exists
|
||||||
fmt.Printf("returning partial prefix: remaining=%s\n", s)
|
return nil, s, nil
|
||||||
return nil, s
|
|
||||||
}
|
}
|
||||||
// * we'd be returning here for just accumulating with possible prefix
|
// accumulate case
|
||||||
// * ext state is accumulation
|
return nil, "", nil
|
||||||
return nil, ""
|
|
||||||
}
|
}
|
||||||
// * lets say the check fails here and now we're still in external state accumulation here
|
|
||||||
|
|
||||||
// stay in SendTokens unless we have a prefix
|
// circuit breaker
|
||||||
if p.state == SendTokens {
|
if prefixFound {
|
||||||
p.updateExternalState(nil)
|
p.prefixFound = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// for cases with a prefix in template
|
||||||
|
if p.prefix != "" && !p.greedy && !p.prefixFound {
|
||||||
|
// send tokens down
|
||||||
p.sb.Reset()
|
p.sb.Reset()
|
||||||
fmt.Printf("returning send tokens: remaining=%s\n", s)
|
return nil, "", errors.New("prefix not found")
|
||||||
return nil, s
|
|
||||||
}
|
}
|
||||||
|
// we have a prefix or are in json mode
|
||||||
// * we'd parse here as json to see if it's a tool call
|
|
||||||
tcs, partial, ok := p.parseJSONToolCalls(s)
|
tcs, partial, ok := p.parseJSONToolCalls(s)
|
||||||
// * it would not be a tool call here
|
if partial {
|
||||||
p.updateStateAfterJSONParse(ok, partial, tcs)
|
// accumulate case
|
||||||
if !ok {
|
return nil, "", nil
|
||||||
// * and so we should send the data here
|
|
||||||
// * we also need to move out of that internal state after sending the tokens
|
|
||||||
if p.state == NotPartialPrefix {
|
|
||||||
p.state = SendTokens
|
|
||||||
// the string would have acc until here
|
|
||||||
return nil, p.sb.String()
|
|
||||||
}
|
|
||||||
return nil, ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.greedy = false
|
||||||
|
if !ok {
|
||||||
|
// will not be a partial at this point
|
||||||
|
p.sb.Reset()
|
||||||
|
// send tokens
|
||||||
|
if p.prefix == "" {
|
||||||
|
p.Done = true
|
||||||
|
}
|
||||||
|
if p.prefixFound {
|
||||||
|
// drop tokens instead - sb is reset, no tokens sent to user
|
||||||
|
return nil, "", nil
|
||||||
|
}
|
||||||
|
return nil, "", errors.New("failed to parse tool calls")
|
||||||
|
}
|
||||||
|
|
||||||
for _, tc := range tcs {
|
for _, tc := range tcs {
|
||||||
tc.Function.Index = p.toolIndex
|
tc.Function.Index = p.index
|
||||||
p.toolIndex++
|
p.index++
|
||||||
}
|
}
|
||||||
fmt.Printf("finished parsing tool calls: tool_calls_found=%d\n", len(tcs))
|
if p.prefix == "" {
|
||||||
return tcs, ""
|
p.Done = true
|
||||||
|
}
|
||||||
|
p.sb.Reset()
|
||||||
|
return tcs, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func suffixOverlap(s, delim string) int {
|
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||||
max := min(len(delim), len(s))
|
parsedTemplate, err := template.Parse(templateToProcess.Root.String())
|
||||||
for i := max; i > 0; i-- {
|
if err != nil {
|
||||||
if strings.HasSuffix(s, delim[:i]) {
|
return nil, err
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return 0
|
if parsedTemplate == nil {
|
||||||
}
|
return nil, errors.New("failed to parse template")
|
||||||
|
|
||||||
// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls"
|
|
||||||
func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
|
||||||
if tmpl == nil || tmpl.Tree == nil {
|
|
||||||
slog.Debug("TextAfterToolCalls: template or tree is nil")
|
|
||||||
return "", false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var result string
|
toolCallTemplate, hasToolCalls := toolTemplate(parsedTemplate)
|
||||||
var found bool
|
if !hasToolCalls {
|
||||||
|
return nil, errors.New("failed to find tool template")
|
||||||
var walk func(nodes []parse.Node)
|
}
|
||||||
walk = func(nodes []parse.Node) {
|
if toolCallTemplate == nil {
|
||||||
for _, node := range nodes {
|
return nil, errors.New("failed to find tool template")
|
||||||
if found {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch n := node.(type) {
|
|
||||||
case *parse.IfNode:
|
|
||||||
if nodeContainsToolCalls(n) {
|
|
||||||
// Collect immediate TextNode(s) at start of IfNode's list
|
|
||||||
var sb strings.Builder
|
|
||||||
for _, innerNode := range n.List.Nodes {
|
|
||||||
if tn, ok := innerNode.(*parse.TextNode); ok {
|
|
||||||
sb.Write(tn.Text)
|
|
||||||
} else {
|
|
||||||
// Stop at first non-text node
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result = sb.String()
|
|
||||||
found = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Recurse into child nodes
|
|
||||||
walk(n.List.Nodes)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
walk(n.ElseList.Nodes)
|
|
||||||
}
|
|
||||||
case *parse.ListNode:
|
|
||||||
walk(n.Nodes)
|
|
||||||
case *parse.RangeNode:
|
|
||||||
walk(n.List.Nodes)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
walk(n.ElseList.Nodes)
|
|
||||||
}
|
|
||||||
case *parse.WithNode:
|
|
||||||
walk(n.List.Nodes)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
walk(n.ElseList.Nodes)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// Continue to next node
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
walk(tmpl.Tree.Root.Nodes)
|
toolPrefix, _ := ToolPrefix(templateToProcess)
|
||||||
return result, found
|
toolPrefix = strings.TrimSpace(toolPrefix)
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to detect if a node's condition includes ".ToolCalls"
|
fmt.Printf("creating new tool parser: prefix=%s\n", toolPrefix)
|
||||||
func nodeContainsToolCalls(n *parse.IfNode) bool {
|
|
||||||
for _, cmd := range n.Pipe.Cmds {
|
|
||||||
for _, arg := range cmd.Args {
|
|
||||||
if field, ok := arg.(*parse.FieldNode); ok {
|
|
||||||
if slices.Contains(field.Ident, "ToolCalls") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
|
|
||||||
tokenText, ok := extractToolCallsTemplate(tmpl)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
tokenText = strings.TrimSpace(tokenText)
|
|
||||||
if tokenText == "" {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
first := strings.Fields(tokenText)[0]
|
|
||||||
|
|
||||||
start := -1
|
|
||||||
end := -1
|
|
||||||
for i, r := range tokenText {
|
|
||||||
if r == '<' || r == '[' {
|
|
||||||
start = i
|
|
||||||
}
|
|
||||||
if (r == '>' || r == ']') && start != -1 {
|
|
||||||
end = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if start != -1 && end != -1 {
|
|
||||||
// return the token including the [ or < and the ] or >
|
|
||||||
return tokenText[start : end+1], true
|
|
||||||
} else if start != -1 {
|
|
||||||
// get until the [ or < - in the case tag was not closed
|
|
||||||
return tokenText[:start], true
|
|
||||||
} else if end != -1 {
|
|
||||||
// get after the ] or > - in the case tag was not opened
|
|
||||||
return tokenText[end+1:], true
|
|
||||||
}
|
|
||||||
return first, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewParser(tmpl *gotmpl.Template, toolTemplate *gotmpl.Template) *Parser {
|
|
||||||
// TODO: use new template parsing to get all tokens for the prefix
|
|
||||||
if tmpl == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if toolTemplate == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix, _ := ToolPrefix(tmpl)
|
|
||||||
prefix = strings.TrimSpace(prefix)
|
|
||||||
|
|
||||||
var state State
|
|
||||||
if prefix == "" {
|
|
||||||
state = GreedyToolNoPrefix
|
|
||||||
} else {
|
|
||||||
state = GreedyToolWithPrefix
|
|
||||||
}
|
|
||||||
fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", prefix, state)
|
|
||||||
return &Parser{
|
return &Parser{
|
||||||
tmpl: toolTemplate,
|
tmpl: toolCallTemplate,
|
||||||
sb: &strings.Builder{},
|
sb: &strings.Builder{},
|
||||||
toolPrefix: prefix,
|
prefix: toolPrefix,
|
||||||
state: state,
|
greedy: true,
|
||||||
ParserState: ToolCallAccumulate,
|
}, nil
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -239,14 +239,14 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
model: "qwen3",
|
model: "qwen3",
|
||||||
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
expectedTokens: `<think></think> fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
||||||
model: "qwen3",
|
model: "qwen3",
|
||||||
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
expectedToolCall: []api.ToolCall{},
|
expectedToolCall: []api.ToolCall{},
|
||||||
expectedTokens: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
expectedTokens: `<think></think> fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen3 invalid tool call with malformed tool prefix",
|
name: "qwen3 invalid tool call with malformed tool prefix",
|
||||||
@ -315,34 +315,31 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("parse", func(t *testing.T) {
|
t.Run("parse", func(t *testing.T) {
|
||||||
// fmt.Printf("tmpl: %s\n", tmpl.Root.String())
|
// fmt.Printf("tmpl: %s\n", tmpl.Root.String())
|
||||||
toolTemplate, ok := toolTemplateHelper(t, tmpl)
|
tp, err := NewParser(tmpl.Template)
|
||||||
if !ok {
|
if err != nil {
|
||||||
t.Fatalf("tool template not found for model %s", tt.model)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tp := NewParser(tmpl.Template, toolTemplate)
|
|
||||||
got := []api.ToolCall{}
|
got := []api.ToolCall{}
|
||||||
var gotTokens strings.Builder
|
var gotTokens strings.Builder
|
||||||
|
|
||||||
|
var add bool
|
||||||
tokens := strings.Fields(tt.output)
|
tokens := strings.Fields(tt.output)
|
||||||
for _, tok := range tokens {
|
for _, tok := range tokens {
|
||||||
add := true
|
|
||||||
s := " " + tok
|
s := " " + tok
|
||||||
|
|
||||||
|
add = true
|
||||||
if !tp.Done {
|
if !tp.Done {
|
||||||
toolCalls, leftover := tp.ParseToolCalls(s)
|
toolCalls, content, err := tp.Add(s)
|
||||||
switch tp.ParserState {
|
if err == nil {
|
||||||
case ToolCallFound:
|
if content != "" {
|
||||||
got = append(got, toolCalls...)
|
gotTokens.WriteString(content)
|
||||||
add = false
|
add = false
|
||||||
case ToolCallSendTokens:
|
} else if len(toolCalls) > 0 {
|
||||||
gotTokens.WriteString(s)
|
got = append(got, toolCalls...)
|
||||||
add = false
|
add = false
|
||||||
case ToolCallAccumulate:
|
} else {
|
||||||
add = false
|
add = false
|
||||||
case ToolCallSendPartial:
|
}
|
||||||
t.Log("send partial", "leftover", leftover)
|
|
||||||
gotTokens.WriteString(" " + leftover)
|
|
||||||
add = false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if add {
|
if add {
|
||||||
|
155
tools/utils.go
Normal file
155
tools/utils.go
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
gotmpl "text/template"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls"
|
||||||
|
func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
||||||
|
if tmpl == nil || tmpl.Tree == nil {
|
||||||
|
slog.Debug("TextAfterToolCalls: template or tree is nil")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
var result string
|
||||||
|
var found bool
|
||||||
|
|
||||||
|
var walk func(nodes []parse.Node)
|
||||||
|
walk = func(nodes []parse.Node) {
|
||||||
|
for _, node := range nodes {
|
||||||
|
if found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *parse.IfNode:
|
||||||
|
if nodeContainsToolCalls(n) {
|
||||||
|
// Collect immediate TextNode(s) at start of IfNode's list
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, innerNode := range n.List.Nodes {
|
||||||
|
if tn, ok := innerNode.(*parse.TextNode); ok {
|
||||||
|
sb.Write(tn.Text)
|
||||||
|
} else {
|
||||||
|
// Stop at first non-text node
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result = sb.String()
|
||||||
|
found = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Recurse into child nodes
|
||||||
|
walk(n.List.Nodes)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
walk(n.ElseList.Nodes)
|
||||||
|
}
|
||||||
|
case *parse.ListNode:
|
||||||
|
walk(n.Nodes)
|
||||||
|
case *parse.RangeNode:
|
||||||
|
walk(n.List.Nodes)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
walk(n.ElseList.Nodes)
|
||||||
|
}
|
||||||
|
case *parse.WithNode:
|
||||||
|
walk(n.List.Nodes)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
walk(n.ElseList.Nodes)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Continue to next node
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
walk(tmpl.Tree.Root.Nodes)
|
||||||
|
return result, found
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to detect if a node's condition includes ".ToolCalls"
|
||||||
|
func nodeContainsToolCalls(n *parse.IfNode) bool {
|
||||||
|
for _, cmd := range n.Pipe.Cmds {
|
||||||
|
for _, arg := range cmd.Args {
|
||||||
|
if field, ok := arg.(*parse.FieldNode); ok {
|
||||||
|
if slices.Contains(field.Ident, "ToolCalls") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolPrefix returns the prefix for the tool call if it exists
|
||||||
|
// TODO(parthsareen): get full prefix from the template instead of just the first token
|
||||||
|
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
|
||||||
|
tokenText, ok := extractToolCallsTemplate(tmpl)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
tokenText = strings.TrimSpace(tokenText)
|
||||||
|
if tokenText == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
first := strings.Fields(tokenText)[0]
|
||||||
|
|
||||||
|
start := -1
|
||||||
|
end := -1
|
||||||
|
for i, r := range tokenText {
|
||||||
|
if r == '<' || r == '[' {
|
||||||
|
start = i
|
||||||
|
}
|
||||||
|
if (r == '>' || r == ']') && start != -1 {
|
||||||
|
end = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if start != -1 && end != -1 {
|
||||||
|
// return the token including the [ or < and the ] or >
|
||||||
|
return tokenText[start : end+1], true
|
||||||
|
} else if start != -1 {
|
||||||
|
// get until the [ or < - in the case tag was not closed
|
||||||
|
return tokenText[:start], true
|
||||||
|
} else if end != -1 {
|
||||||
|
// get after the ] or > - in the case tag was not opened
|
||||||
|
return tokenText[end+1:], true
|
||||||
|
}
|
||||||
|
return first, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
|
||||||
|
// create a subtree from the node that ranges over .ToolCalls
|
||||||
|
tmpl := t.Subtree(func(n parse.Node) bool {
|
||||||
|
if t, ok := n.(*parse.RangeNode); ok {
|
||||||
|
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
if tmpl == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return tmpl, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func suffixOverlap(s, delim string) int {
|
||||||
|
max := min(len(delim), len(s))
|
||||||
|
for i := max; i > 0; i-- {
|
||||||
|
if strings.HasSuffix(s, delim[:i]) {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
185
tools/utils_test.go
Normal file
185
tools/utils_test.go
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
gotmpl "text/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToolPrefix(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
want string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic tool call with action prefix",
|
||||||
|
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||||
|
want: "Action:",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete functools bracket",
|
||||||
|
template: "{{if .ToolCalls}}functools[{{end}}",
|
||||||
|
want: "functools",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with angle brackets",
|
||||||
|
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
||||||
|
want: "<tool_call>",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool call formats",
|
||||||
|
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
||||||
|
want: "[tool_call]",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single angle bracket tool call",
|
||||||
|
template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
||||||
|
want: "<tool_call>",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete angle bracket after tool call",
|
||||||
|
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
||||||
|
want: "[tool_call]",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "angle bracket prefix with tool call",
|
||||||
|
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
||||||
|
want: "<tool_call>",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase tool call with incomplete bracket",
|
||||||
|
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||||
|
want: "[TOOL_CALL]",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase tool call with adjacent bracket",
|
||||||
|
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||||
|
want: "[TOOL_CALL]",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with pipe delimiters",
|
||||||
|
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
||||||
|
want: "<|tool_call|>",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
got, ok := ToolPrefix(tmpl)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
||||||
|
}
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Errorf("ToolToken(%q) = %v; want %v", tt.template, ok, tt.ok)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextAfterToolCalls(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
want string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic tool call with text after",
|
||||||
|
template: `{{if .ToolCalls}}tool response{{end}}`,
|
||||||
|
want: "tool response",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with mixed content after",
|
||||||
|
template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`,
|
||||||
|
want: "<tool_call>",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with no text after",
|
||||||
|
template: `{{if .ToolCalls}}{{.Something}}{{end}}`,
|
||||||
|
want: "",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested tool call",
|
||||||
|
template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`,
|
||||||
|
want: "[TOOL_CALL]",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no tool calls",
|
||||||
|
template: `{{if .Something}}no tools here{{end}}`,
|
||||||
|
want: "",
|
||||||
|
ok: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty template",
|
||||||
|
template: ``,
|
||||||
|
want: "",
|
||||||
|
ok: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool calls sections",
|
||||||
|
template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`,
|
||||||
|
want: "first",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "range over tool calls",
|
||||||
|
template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`,
|
||||||
|
want: "",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool calls with pipe delimiters",
|
||||||
|
template: `{{if .ToolCalls}}<|tool|>{{end}}`,
|
||||||
|
want: "<|tool|>",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool calls with nested template",
|
||||||
|
template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`,
|
||||||
|
want: "",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool calls with whitespace variations",
|
||||||
|
template: `{{if .ToolCalls}} tool {{end}}`,
|
||||||
|
want: " tool ",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok := extractToolCallsTemplate(tmpl)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user