renaming and splitting stuff up
This commit is contained in:
parent
b8b9c0c7cf
commit
4059b8db01
125
server/model.go
125
server/model.go
@ -11,7 +11,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
gotmpl "text/template"
|
gotmpl "text/template"
|
||||||
"text/template/parse"
|
"text/template/parse"
|
||||||
|
|
||||||
@ -130,130 +129,6 @@ func detectContentType(r io.Reader) (string, error) {
|
|||||||
return "unknown", nil
|
return "unknown", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls"
|
|
||||||
func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
|
||||||
if tmpl == nil || tmpl.Tree == nil {
|
|
||||||
slog.Debug("TextAfterToolCalls: template or tree is nil")
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
var result string
|
|
||||||
var found bool
|
|
||||||
|
|
||||||
var walk func(nodes []parse.Node)
|
|
||||||
walk = func(nodes []parse.Node) {
|
|
||||||
for _, node := range nodes {
|
|
||||||
if found {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch n := node.(type) {
|
|
||||||
case *parse.IfNode:
|
|
||||||
if nodeContainsToolCalls(n) {
|
|
||||||
// Collect immediate TextNode(s) at start of IfNode's list
|
|
||||||
var sb strings.Builder
|
|
||||||
for _, innerNode := range n.List.Nodes {
|
|
||||||
if tn, ok := innerNode.(*parse.TextNode); ok {
|
|
||||||
sb.Write(tn.Text)
|
|
||||||
} else {
|
|
||||||
// Stop at first non-text node
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result = sb.String()
|
|
||||||
found = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Recurse into child nodes
|
|
||||||
walk(n.List.Nodes)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
walk(n.ElseList.Nodes)
|
|
||||||
}
|
|
||||||
case *parse.ListNode:
|
|
||||||
walk(n.Nodes)
|
|
||||||
case *parse.RangeNode:
|
|
||||||
walk(n.List.Nodes)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
walk(n.ElseList.Nodes)
|
|
||||||
}
|
|
||||||
case *parse.WithNode:
|
|
||||||
walk(n.List.Nodes)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
walk(n.ElseList.Nodes)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// Continue to next node
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
walk(tmpl.Tree.Root.Nodes)
|
|
||||||
return result, found
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to detect if a node's condition includes ".ToolCalls"
|
|
||||||
func nodeContainsToolCalls(n *parse.IfNode) bool {
|
|
||||||
for _, cmd := range n.Pipe.Cmds {
|
|
||||||
for _, arg := range cmd.Args {
|
|
||||||
if field, ok := arg.(*parse.FieldNode); ok {
|
|
||||||
if slices.Contains(field.Ident, "ToolCalls") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func ToolPrefix2(tmpl *gotmpl.Template) (string, bool) {
|
|
||||||
tokenText, ok := extractToolCallsTemplate(tmpl)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
tokenText = strings.TrimSpace(tokenText)
|
|
||||||
return tokenText, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
|
|
||||||
tokenText, ok := extractToolCallsTemplate(tmpl)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
tokenText = strings.TrimSpace(tokenText)
|
|
||||||
if tokenText == "" {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
first := strings.Fields(tokenText)[0]
|
|
||||||
|
|
||||||
start := -1
|
|
||||||
end := -1
|
|
||||||
for i, r := range tokenText {
|
|
||||||
if r == '<' || r == '[' {
|
|
||||||
start = i
|
|
||||||
}
|
|
||||||
if (r == '>' || r == ']') && start != -1 {
|
|
||||||
end = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if start != -1 && end != -1 {
|
|
||||||
// return the token including the [ or < and the ] or >
|
|
||||||
return tokenText[start : end+1], true
|
|
||||||
} else if start != -1 {
|
|
||||||
// get until the [ or < - in the case tag was not closed
|
|
||||||
return tokenText[:start], true
|
|
||||||
} else if end != -1 {
|
|
||||||
// get after the ] or > - in the case tag was not opened
|
|
||||||
return tokenText[end+1:], true
|
|
||||||
}
|
|
||||||
return first, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func ToolTemplate(m *Model) (*gotmpl.Template, bool) {
|
func ToolTemplate(m *Model) (*gotmpl.Template, bool) {
|
||||||
// create a subtree from the node that ranges over .ToolCalls
|
// create a subtree from the node that ranges over .ToolCalls
|
||||||
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
||||||
|
@ -38,6 +38,7 @@ import (
|
|||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/server/internal/registry"
|
"github.com/ollama/ollama/server/internal/registry"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
|
"github.com/ollama/ollama/tools"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@ -1485,9 +1486,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
var toolParser *ToolParser
|
// ! personally not a fan of this pattern
|
||||||
|
toolTemplate, ok := ToolTemplate(m)
|
||||||
|
if !ok {
|
||||||
|
slog.Error("tool template not found", "model", m.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var toolParser *tools.Parser
|
||||||
if len(req.Tools) > 0 {
|
if len(req.Tools) > 0 {
|
||||||
toolParser = NewToolParser(m)
|
toolParser = tools.NewParser(m.Template.Template, toolTemplate)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
@ -1521,18 +1528,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
// * However, we'd need a flag to indicate whether to send the response or not
|
// * However, we'd need a flag to indicate whether to send the response or not
|
||||||
// * happy to take whatever is more idiomatic
|
// * happy to take whatever is more idiomatic
|
||||||
switch toolParser.ParserState {
|
switch toolParser.ParserState {
|
||||||
case ToolCallAccumulate:
|
case tools.ToolCallAccumulate:
|
||||||
// tokens are accumulated in the tool parser
|
// tokens are accumulated in the tool parser
|
||||||
return
|
return
|
||||||
case ToolCallSendTokens:
|
case tools.ToolCallSendTokens:
|
||||||
// tokens are sent back in the response
|
// tokens are sent back in the response
|
||||||
case ToolCallSendPartial:
|
case tools.ToolCallSendPartial:
|
||||||
// tokens not needed for parsing are sent back in the response
|
// tokens not needed for parsing are sent back in the response
|
||||||
if len(leftover) > 0 {
|
if len(leftover) > 0 {
|
||||||
res.Message.Content = leftover
|
res.Message.Content = leftover
|
||||||
}
|
}
|
||||||
// ! state is needed as we need to not match on the other states
|
// ! state is needed as we need to not match on the other states
|
||||||
case ToolCallFound:
|
case tools.ToolCallFound:
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
res.Message.Content = ""
|
res.Message.Content = ""
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
package server
|
package tools
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
gotmpl "text/template"
|
gotmpl "text/template"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
jsonv2 "github.com/go-json-experiment/json"
|
jsonv2 "github.com/go-json-experiment/json"
|
||||||
jsontext "github.com/go-json-experiment/json/jsontext"
|
jsontext "github.com/go-json-experiment/json/jsontext"
|
||||||
@ -77,7 +80,7 @@ func (s State) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: simplify if possible
|
// TODO: simplify if possible
|
||||||
type ToolParser struct {
|
type Parser struct {
|
||||||
tmpl *gotmpl.Template
|
tmpl *gotmpl.Template
|
||||||
state State
|
state State
|
||||||
sb *strings.Builder
|
sb *strings.Builder
|
||||||
@ -90,7 +93,7 @@ type ToolParser struct {
|
|||||||
// ? move to a separate file
|
// ? move to a separate file
|
||||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
|
// Returns parsed tool calls, a boolean indicating if the JSON is incomplete, and a boolean indicating if the tool calls were found
|
||||||
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||||
fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s)
|
fmt.Printf("attempting to parse JSON tool calls: input=%s\n", s)
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
@ -220,7 +223,7 @@ func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: clean up the boundary of internal and external state transitions
|
// TODO: clean up the boundary of internal and external state transitions
|
||||||
func (p *ToolParser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) {
|
func (p *Parser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.ToolCall) {
|
||||||
fmt.Printf("updating output state: ok=%v partial=%v tool_calls=%d current_state=%s\n", ok, partial, len(tcs), p.state)
|
fmt.Printf("updating output state: ok=%v partial=%v tool_calls=%d current_state=%s\n", ok, partial, len(tcs), p.state)
|
||||||
|
|
||||||
// state transition logic
|
// state transition logic
|
||||||
@ -252,7 +255,7 @@ func (p *ToolParser) updateStateAfterJSONParse(ok bool, partial bool, tcs []api.
|
|||||||
fmt.Printf("state updated: new_state=%s parser_state=%s\n", p.state, p.ParserState)
|
fmt.Printf("state updated: new_state=%s parser_state=%s\n", p.state, p.ParserState)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ToolParser) updateExternalState(tcs []api.ToolCall) {
|
func (p *Parser) updateExternalState(tcs []api.ToolCall) {
|
||||||
fmt.Printf("updating external state: current_state=%s tool_calls=%d\n", p.state, len(tcs))
|
fmt.Printf("updating external state: current_state=%s tool_calls=%d\n", p.state, len(tcs))
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
@ -283,7 +286,7 @@ func (p *ToolParser) updateExternalState(tcs []api.ToolCall) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// string, and if it has a prefix
|
// string, and if it has a prefix
|
||||||
func (p *ToolParser) checkPrefix(s string) (string, bool) {
|
func (p *Parser) checkPrefix(s string) (string, bool) {
|
||||||
fmt.Printf("checking prefix: input=%s prefix=%s\n", s, p.toolPrefix)
|
fmt.Printf("checking prefix: input=%s prefix=%s\n", s, p.toolPrefix)
|
||||||
|
|
||||||
if p.toolPrefix == "" {
|
if p.toolPrefix == "" {
|
||||||
@ -322,7 +325,7 @@ func (p *ToolParser) checkPrefix(s string) (string, bool) {
|
|||||||
// TODO: simplify the flow of this function
|
// TODO: simplify the flow of this function
|
||||||
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
||||||
// Returns tool calls, whether parsing is incomplete, and any errors.
|
// Returns tool calls, whether parsing is incomplete, and any errors.
|
||||||
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, string) {
|
func (p *Parser) ParseToolCalls(s string) ([]api.ToolCall, string) {
|
||||||
fmt.Printf("parsing tool calls: input=%s current_state=%s\n", s, p.state)
|
fmt.Printf("parsing tool calls: input=%s current_state=%s\n", s, p.state)
|
||||||
|
|
||||||
p.sb.WriteString(s)
|
p.sb.WriteString(s)
|
||||||
@ -388,26 +391,144 @@ func suffixOverlap(s, delim string) int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewToolParser(model *Model) *ToolParser {
|
// extractToolCallsTemplate finds the immediate following text after any IfNode containing ".ToolCalls"
|
||||||
// TODO: use new template parsing to get all tokens for the prefix
|
func extractToolCallsTemplate(tmpl *gotmpl.Template) (string, bool) {
|
||||||
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
|
if tmpl == nil || tmpl.Tree == nil {
|
||||||
templateToolPrefix = strings.TrimSpace(templateToolPrefix)
|
slog.Debug("TextAfterToolCalls: template or tree is nil")
|
||||||
tmpl, ok := ToolTemplate(model)
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
var result string
|
||||||
|
var found bool
|
||||||
|
|
||||||
|
var walk func(nodes []parse.Node)
|
||||||
|
walk = func(nodes []parse.Node) {
|
||||||
|
for _, node := range nodes {
|
||||||
|
if found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *parse.IfNode:
|
||||||
|
if nodeContainsToolCalls(n) {
|
||||||
|
// Collect immediate TextNode(s) at start of IfNode's list
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, innerNode := range n.List.Nodes {
|
||||||
|
if tn, ok := innerNode.(*parse.TextNode); ok {
|
||||||
|
sb.Write(tn.Text)
|
||||||
|
} else {
|
||||||
|
// Stop at first non-text node
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result = sb.String()
|
||||||
|
found = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Recurse into child nodes
|
||||||
|
walk(n.List.Nodes)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
walk(n.ElseList.Nodes)
|
||||||
|
}
|
||||||
|
case *parse.ListNode:
|
||||||
|
walk(n.Nodes)
|
||||||
|
case *parse.RangeNode:
|
||||||
|
walk(n.List.Nodes)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
walk(n.ElseList.Nodes)
|
||||||
|
}
|
||||||
|
case *parse.WithNode:
|
||||||
|
walk(n.List.Nodes)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
walk(n.ElseList.Nodes)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Continue to next node
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
walk(tmpl.Tree.Root.Nodes)
|
||||||
|
return result, found
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to detect if a node's condition includes ".ToolCalls"
|
||||||
|
func nodeContainsToolCalls(n *parse.IfNode) bool {
|
||||||
|
for _, cmd := range n.Pipe.Cmds {
|
||||||
|
for _, arg := range cmd.Args {
|
||||||
|
if field, ok := arg.(*parse.FieldNode); ok {
|
||||||
|
if slices.Contains(field.Ident, "ToolCalls") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
|
||||||
|
tokenText, ok := extractToolCallsTemplate(tmpl)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
tokenText = strings.TrimSpace(tokenText)
|
||||||
|
if tokenText == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
first := strings.Fields(tokenText)[0]
|
||||||
|
|
||||||
|
start := -1
|
||||||
|
end := -1
|
||||||
|
for i, r := range tokenText {
|
||||||
|
if r == '<' || r == '[' {
|
||||||
|
start = i
|
||||||
|
}
|
||||||
|
if (r == '>' || r == ']') && start != -1 {
|
||||||
|
end = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if start != -1 && end != -1 {
|
||||||
|
// return the token including the [ or < and the ] or >
|
||||||
|
return tokenText[start : end+1], true
|
||||||
|
} else if start != -1 {
|
||||||
|
// get until the [ or < - in the case tag was not closed
|
||||||
|
return tokenText[:start], true
|
||||||
|
} else if end != -1 {
|
||||||
|
// get after the ] or > - in the case tag was not opened
|
||||||
|
return tokenText[end+1:], true
|
||||||
|
}
|
||||||
|
return first, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewParser(tmpl *gotmpl.Template, toolTemplate *gotmpl.Template) *Parser {
|
||||||
|
// TODO: use new template parsing to get all tokens for the prefix
|
||||||
|
if tmpl == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if toolTemplate == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prefix, _ := ToolPrefix(tmpl)
|
||||||
|
prefix = strings.TrimSpace(prefix)
|
||||||
|
|
||||||
var state State
|
var state State
|
||||||
if templateToolPrefix == "" {
|
if prefix == "" {
|
||||||
state = GreedyToolNoPrefix
|
state = GreedyToolNoPrefix
|
||||||
} else {
|
} else {
|
||||||
state = GreedyToolWithPrefix
|
state = GreedyToolWithPrefix
|
||||||
}
|
}
|
||||||
fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", templateToolPrefix, state)
|
fmt.Printf("creating new tool parser: prefix=%s initial_state=%s\n", prefix, state)
|
||||||
return &ToolParser{
|
return &Parser{
|
||||||
tmpl: tmpl,
|
tmpl: toolTemplate,
|
||||||
sb: &strings.Builder{},
|
sb: &strings.Builder{},
|
||||||
toolPrefix: templateToolPrefix,
|
toolPrefix: prefix,
|
||||||
state: state,
|
state: state,
|
||||||
ParserState: ToolCallAccumulate,
|
ParserState: ToolCallAccumulate,
|
||||||
}
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package server
|
package tools
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -6,8 +6,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
gotmpl "text/template"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
@ -27,7 +30,7 @@ func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestParseToolCalls(t *testing.T) {
|
func TestParseToolCalls(t *testing.T) {
|
||||||
p := filepath.Join("testdata", "tools")
|
p := filepath.Join("testdata")
|
||||||
t1 := api.ToolCall{
|
t1 := api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
@ -311,8 +314,12 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("parse", func(t *testing.T) {
|
t.Run("parse", func(t *testing.T) {
|
||||||
m := &Model{Template: tmpl}
|
// fmt.Printf("tmpl: %s\n", tmpl.Root.String())
|
||||||
tp := NewToolParser(m)
|
toolTemplate, ok := toolTemplateHelper(t, tmpl)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("tool template not found for model %s", tt.model)
|
||||||
|
}
|
||||||
|
tp := NewParser(tmpl.Template, toolTemplate)
|
||||||
got := []api.ToolCall{}
|
got := []api.ToolCall{}
|
||||||
var gotTokens strings.Builder
|
var gotTokens strings.Builder
|
||||||
|
|
||||||
@ -358,3 +365,25 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toolTemplateHelper(t *testing.T, tmpl *template.Template) (*gotmpl.Template, bool) {
|
||||||
|
// create a subtree from the node that ranges over .ToolCalls
|
||||||
|
|
||||||
|
tmpl2 := tmpl.Subtree(func(n parse.Node) bool {
|
||||||
|
if t, ok := n.(*parse.RangeNode); ok {
|
||||||
|
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
if tmpl2.Root != nil {
|
||||||
|
t.Log("tmpl2", tmpl2.Root.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if tmpl2 == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return tmpl2, true
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user