jsonv2 decoder
This commit is contained in:
parent
7f2f996cd6
commit
516a540df7
1
go.mod
1
go.mod
@ -35,6 +35,7 @@ require (
|
|||||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||||
github.com/kr/text v0.2.0 // indirect
|
github.com/kr/text v0.2.0 // indirect
|
||||||
|
2
go.sum
2
go.sum
@ -69,6 +69,8 @@ github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3
|
|||||||
github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY=
|
github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY=
|
||||||
github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY=
|
github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY=
|
||||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||||
|
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 h1:+VexzzkMLb1tnvpuQdGT/DicIRW7MN8ozsXqBMgp0Hk=
|
||||||
|
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M=
|
||||||
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
|
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
|
||||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
|
@ -210,7 +210,16 @@ func nodeContainsToolCalls(n *parse.IfNode) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToolToken(tmpl *gotmpl.Template) (string, bool) {
|
func ToolPrefix2(tmpl *gotmpl.Template) (string, bool) {
|
||||||
|
tokenText, ok := extractToolCallsTemplate(tmpl)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
tokenText = strings.TrimSpace(tokenText)
|
||||||
|
return tokenText, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
|
||||||
tokenText, ok := extractToolCallsTemplate(tmpl)
|
tokenText, ok := extractToolCallsTemplate(tmpl)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", false
|
return "", false
|
||||||
|
@ -80,7 +80,7 @@ func TestToolToken(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse template: %v", err)
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
}
|
}
|
||||||
got, ok := ToolToken(tmpl)
|
got, ok := ToolPrefix(tmpl)
|
||||||
if got != tt.want {
|
if got != tt.want {
|
||||||
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,6 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
gotmpl "text/template"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
@ -1486,26 +1485,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
var sb strings.Builder
|
// var sb strings.Builder
|
||||||
var toolCallIndex int = 0
|
var toolCallIndex int = 0
|
||||||
var templateToolToken string
|
var tp *ToolParser
|
||||||
var tmpl *gotmpl.Template
|
|
||||||
if len(req.Tools) > 0 {
|
if len(req.Tools) > 0 {
|
||||||
var ok bool
|
tp = NewToolParser(m)
|
||||||
templateToolToken, ok = ToolToken(m.Template.Template)
|
|
||||||
if !ok {
|
|
||||||
slog.Debug("no tool token found")
|
|
||||||
}
|
|
||||||
tmpl, ok = ToolTemplate(m)
|
|
||||||
if !ok {
|
|
||||||
slog.Debug("no tool template found")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
checkToolCall := false
|
|
||||||
if len(req.Tools) > 0 {
|
|
||||||
checkToolCall = true
|
|
||||||
}
|
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
@ -1526,50 +1512,29 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if r.Done {
|
if r.Done {
|
||||||
if sb.Len() > 0 {
|
|
||||||
res.Message.Content = sb.String()
|
|
||||||
}
|
|
||||||
res.DoneReason = r.DoneReason.String()
|
res.DoneReason = r.DoneReason.String()
|
||||||
res.TotalDuration = time.Since(checkpointStart)
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(r.Content)
|
if len(req.Tools) > 0 && !tp.done {
|
||||||
if len(req.Tools) > 0 && checkToolCall {
|
fmt.Println("checking tool calls")
|
||||||
slog.Debug("parse tool calls", "content", sb.String(), "templateToolToken", templateToolToken)
|
toolCalls, ok := tp.ParseToolCalls(r.Content)
|
||||||
toolCalls, partial, err := ParseToolCalls(sb.String(), templateToolToken, tmpl)
|
if tp.state == PartialTool {
|
||||||
if err == nil {
|
fmt.Println("partial tool, returning")
|
||||||
if partial {
|
return
|
||||||
// circuit break to remove tool end token
|
}
|
||||||
if len(toolCalls) > 0 {
|
if ok && len(toolCalls) > 0 {
|
||||||
sb.Reset()
|
|
||||||
}
|
|
||||||
// If the tool call is partial, we need to wait for the next chunk
|
|
||||||
return
|
|
||||||
}
|
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
for i := range toolCalls {
|
for i := range toolCalls {
|
||||||
toolCalls[i].Function.Index = toolCallIndex
|
toolCalls[i].Function.Index = toolCallIndex
|
||||||
toolCallIndex++
|
toolCallIndex++
|
||||||
}
|
}
|
||||||
|
// Remove content when tool call is present
|
||||||
res.Message.Content = ""
|
res.Message.Content = ""
|
||||||
ch <- res
|
|
||||||
// Only way to have multiple calls is to have [] which is derived or provided
|
|
||||||
// This case occurs when the tool call is a json block - do not allow tool calls again
|
|
||||||
if templateToolToken == "" || (templateToolToken != "" && !strings.HasPrefix(sb.String(), templateToolToken)) {
|
|
||||||
checkToolCall = false
|
|
||||||
}
|
|
||||||
sb.Reset()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is no template tool token, we don't need to check for tool calls after the first chunk
|
|
||||||
if templateToolToken == "" {
|
|
||||||
checkToolCall = false
|
|
||||||
}
|
|
||||||
res.Message.Content = sb.String()
|
|
||||||
sb.Reset()
|
|
||||||
ch <- res
|
ch <- res
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
239
server/tools.go
239
server/tools.go
@ -2,50 +2,39 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
gotmpl "text/template"
|
gotmpl "text/template"
|
||||||
|
|
||||||
|
jsonv2 "github.com/go-json-experiment/json"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseObjects(s string) []map[string]any {
|
type State int
|
||||||
var objs []map[string]any
|
|
||||||
for offset := 0; offset < len(s); {
|
const (
|
||||||
var obj map[string]any
|
NoTool State = iota
|
||||||
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
|
PartialTool
|
||||||
err := decoder.Decode(&obj)
|
ToolCall
|
||||||
switch {
|
)
|
||||||
case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF):
|
|
||||||
return objs
|
type ToolParser struct {
|
||||||
case err != nil:
|
tmpl *gotmpl.Template
|
||||||
var syntax *json.SyntaxError
|
state State
|
||||||
var unmarshalType *json.UnmarshalTypeError
|
sb *strings.Builder
|
||||||
switch {
|
toolPrefix string
|
||||||
case errors.As(err, &syntax):
|
done bool
|
||||||
offset += int(syntax.Offset)
|
|
||||||
continue
|
|
||||||
case errors.As(err, &unmarshalType):
|
|
||||||
offset += int(unmarshalType.Offset)
|
|
||||||
continue
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
offset += int(decoder.InputOffset())
|
|
||||||
objs = append(objs, obj)
|
|
||||||
}
|
|
||||||
return objs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
// Returns parsed tool calls and a boolean indicating if the JSON is incomplete
|
// Returns parsed tool calls and a boolean indicating if the JSON is incomplete
|
||||||
func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool) {
|
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||||
"ToolCalls": {
|
"ToolCalls": {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
@ -57,35 +46,18 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool)
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return nil, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
templateObjects := parseObjects(b.String())
|
// slog.Debug("template", "template", b.String())
|
||||||
if len(templateObjects) == 0 {
|
|
||||||
return nil, false
|
// ! this can be either a map or an array
|
||||||
|
var temp any
|
||||||
|
err := jsonv2.Unmarshal(b.Bytes(), &temp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// find the keys that correspond to the name and arguments fields
|
|
||||||
var name, arguments string
|
|
||||||
for k, v := range templateObjects[0] {
|
|
||||||
switch v.(type) {
|
|
||||||
case string:
|
|
||||||
name = k
|
|
||||||
case map[string]any:
|
|
||||||
arguments = k
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if name == "" || arguments == "" {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
responseObjects := parseObjects(s)
|
|
||||||
if len(responseObjects) == 0 {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// collect all nested objects
|
|
||||||
var collect func(any) []map[string]any
|
var collect func(any) []map[string]any
|
||||||
collect = func(obj any) (all []map[string]any) {
|
collect = func(obj any) (all []map[string]any) {
|
||||||
switch o := obj.(type) {
|
switch o := obj.(type) {
|
||||||
@ -103,16 +75,63 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool)
|
|||||||
return all
|
return all
|
||||||
}
|
}
|
||||||
|
|
||||||
var objs []map[string]any
|
var templateObjects []map[string]any
|
||||||
for _, p := range responseObjects {
|
switch t := temp.(type) {
|
||||||
objs = append(objs, collect(p)...)
|
case map[string]any:
|
||||||
|
templateObjects = []map[string]any{t}
|
||||||
|
case []map[string]any:
|
||||||
|
templateObjects = t
|
||||||
|
// ! fallback?
|
||||||
|
case []any:
|
||||||
|
templateObjects = collect(t)
|
||||||
}
|
}
|
||||||
|
if len(templateObjects) == 0 {
|
||||||
|
return nil, false, false
|
||||||
|
}
|
||||||
|
// fmt.Println("template objects", templateObjects)
|
||||||
|
|
||||||
|
// find the keys that correspond to the name and arguments fields
|
||||||
|
var name, arguments string
|
||||||
|
for k, v := range templateObjects[0] {
|
||||||
|
switch v.(type) {
|
||||||
|
case string:
|
||||||
|
name = k
|
||||||
|
case map[string]any:
|
||||||
|
arguments = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" || arguments == "" {
|
||||||
|
return nil, false, false
|
||||||
|
}
|
||||||
|
var responseObjects any
|
||||||
|
err = jsonv2.Unmarshal([]byte(s), &responseObjects)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
|
||||||
|
fmt.Println("Detected partial or incomplete JSON.")
|
||||||
|
fmt.Println("state", p.state)
|
||||||
|
return nil, true, false
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Other error: %v\n", err)
|
||||||
|
fmt.Println("exiting", p.state)
|
||||||
|
return nil, false, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var objs []map[string]any
|
||||||
|
objs = append(objs, collect(responseObjects)...)
|
||||||
|
if len(objs) == 0 {
|
||||||
|
return nil, false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("collected objects", "count", len(objs))
|
||||||
|
|
||||||
var toolCalls []api.ToolCall
|
var toolCalls []api.ToolCall
|
||||||
for _, kv := range objs {
|
for _, kv := range objs {
|
||||||
n, nok := kv[name].(string)
|
n, nok := kv[name].(string)
|
||||||
a, aok := kv[arguments].(map[string]any)
|
a, aok := kv[arguments].(map[string]any)
|
||||||
if nok && aok {
|
if nok && aok {
|
||||||
|
slog.Debug("found valid tool call", "name", n)
|
||||||
toolCalls = append(toolCalls, api.ToolCall{
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: n,
|
Name: n,
|
||||||
@ -122,54 +141,82 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return toolCalls, len(toolCalls) > 0
|
slog.Debug("parsed tool calls", "count", len(toolCalls))
|
||||||
}
|
return toolCalls, len(toolCalls) > 0, true
|
||||||
|
|
||||||
// routeToolParsing is a helper function that routes what kind of tool parsing to use
|
|
||||||
func routeToolParsing(s string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, bool) {
|
|
||||||
if strings.HasPrefix(s, "[{") || strings.HasPrefix(s, "```") || strings.HasPrefix(s, "{") {
|
|
||||||
if toolCalls, ok := parseJSONToolCalls(tmpl, s); ok {
|
|
||||||
return toolCalls, false, true
|
|
||||||
}
|
|
||||||
// in the case the JSON never finishes, the accumulated content should be sent downstream
|
|
||||||
return nil, true, true
|
|
||||||
}
|
|
||||||
// TODO(parthsareen): add python tool call support
|
|
||||||
return nil, false, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
|
||||||
// Returns tool calls, whether parsing is incomplete, and any errors.
|
// Returns the parsed tool calls and a boolean reporting whether parsing succeeded.
|
||||||
func ParseToolCalls(s string, toolToken string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, error) {
|
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
if tmpl == nil {
|
p.sb.WriteString(s)
|
||||||
return nil, false, fmt.Errorf("no template provided")
|
s = p.sb.String()
|
||||||
}
|
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
|
slog.Debug("parse tool calls", "content", s)
|
||||||
|
|
||||||
if len(s) == 0 {
|
if len(s) == 0 {
|
||||||
return nil, false, fmt.Errorf("empty input string")
|
return nil, false
|
||||||
}
|
}
|
||||||
if toolToken != "" {
|
hasPrefix := false
|
||||||
if strings.HasPrefix(s, toolToken) {
|
if p.toolPrefix != "" {
|
||||||
s = strings.TrimSpace(s[len(toolToken):])
|
if strings.HasPrefix(s, p.toolPrefix) {
|
||||||
tc, _, ok := routeToolParsing(s, tmpl)
|
s = strings.TrimSpace(s[len(p.toolPrefix):])
|
||||||
if len(tc) == 0 || !ok {
|
slog.Debug("tool prefix", "prefix", p.toolPrefix, "content", s)
|
||||||
return nil, true, nil
|
p.state = PartialTool
|
||||||
}
|
hasPrefix = true
|
||||||
return tc, false, nil
|
|
||||||
// Special token end case
|
// Special token end case
|
||||||
} else if strings.HasSuffix(s, toolToken[2:]) {
|
} else if strings.HasSuffix(s, p.toolPrefix[2:]) {
|
||||||
tc := api.ToolCall{
|
p.state = PartialTool
|
||||||
Function: api.ToolCallFunction{
|
p.sb.Reset()
|
||||||
Name: toolToken,
|
slog.Debug("setting to no tool", "content", s)
|
||||||
},
|
return nil, false
|
||||||
}
|
|
||||||
return []api.ToolCall{tc}, true, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
tcs, partial, ok := p.parseJSONToolCalls(s)
|
||||||
|
|
||||||
tc, partial, ok := routeToolParsing(s, tmpl)
|
// TODO: figure out how to return the remaining string if not partial anymore
|
||||||
if !ok {
|
// update state
|
||||||
return nil, false, fmt.Errorf("failed to parse tool calls for input: %q", s)
|
switch {
|
||||||
|
case !ok && !partial && hasPrefix:
|
||||||
|
p.state = PartialTool
|
||||||
|
case !ok && !partial:
|
||||||
|
p.state = NoTool
|
||||||
|
case !ok && partial:
|
||||||
|
p.state = PartialTool
|
||||||
|
case len(tcs) > 0:
|
||||||
|
p.state = ToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.state == NoTool || p.state == ToolCall {
|
||||||
|
slog.Debug("resetting string builder", "state", p.state)
|
||||||
|
p.sb.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("returning tool calls", "tool calls", tcs)
|
||||||
|
fmt.Println("end state", p.state)
|
||||||
|
if p.toolPrefix == "" {
|
||||||
|
p.done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("len tcs", len(tcs))
|
||||||
|
return tcs, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewToolParser(model *Model) *ToolParser {
|
||||||
|
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
|
||||||
|
slog.Debug("tool prefix", "prefix", templateToolPrefix)
|
||||||
|
tmpl, ok := ToolTemplate(model)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ToolParser{
|
||||||
|
tmpl: tmpl,
|
||||||
|
sb: &strings.Builder{},
|
||||||
|
toolPrefix: templateToolPrefix,
|
||||||
|
done: false,
|
||||||
}
|
}
|
||||||
return tc, partial, nil
|
|
||||||
}
|
}
|
||||||
|
@ -149,9 +149,10 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
// TODO: fix the spacing issue
|
||||||
name: "qwen with single tool call",
|
name: "qwen with single tool call",
|
||||||
model: "qwen2.5-coder",
|
model: "qwen2.5-coder",
|
||||||
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||||
token: "<tool_call>",
|
token: "<tool_call>",
|
||||||
expected: []api.ToolCall{t1},
|
expected: []api.ToolCall{t1},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
@ -185,7 +186,7 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.model, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -204,25 +205,17 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("parse", func(t *testing.T) {
|
t.Run("parse", func(t *testing.T) {
|
||||||
m := &Model{Template: tmpl}
|
m := &Model{Template: tmpl}
|
||||||
tmpl, ok := ToolTemplate(m)
|
tp := NewToolParser(m)
|
||||||
if !ok {
|
|
||||||
t.Fatal("no tool template found")
|
|
||||||
}
|
|
||||||
got := []api.ToolCall{}
|
got := []api.ToolCall{}
|
||||||
tokens := strings.Fields(tt.output)
|
|
||||||
sb := strings.Builder{}
|
|
||||||
success := false
|
success := false
|
||||||
|
tokens := strings.Fields(tt.output)
|
||||||
for _, tok := range tokens {
|
for _, tok := range tokens {
|
||||||
sb.WriteString(" " + tok)
|
s := " " + tok
|
||||||
toolCalls, partial, err := ParseToolCalls(sb.String(), tt.token, tmpl)
|
toolCalls, ok := tp.ParseToolCalls(s)
|
||||||
if err == nil {
|
if ok {
|
||||||
success = true
|
success = true
|
||||||
}
|
}
|
||||||
if partial {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
got = append(got, toolCalls...)
|
got = append(got, toolCalls...)
|
||||||
sb.Reset()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.wantErr {
|
if !tt.wantErr {
|
||||||
@ -237,45 +230,3 @@ func TestParseToolCalls(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseObjects(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
want []map[string]any
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
|
||||||
want: []map[string]any{
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: `<some_token>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
|
|
||||||
want: []map[string]any{
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: `<some_token>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
|
|
||||||
want: []map[string]any{
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
|
||||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: `{"name": "get_current_weather", "arguments": `,
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.input, func(t *testing.T) {
|
|
||||||
got := parseObjects(tc.input)
|
|
||||||
|
|
||||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user