jsonv2 decoder

This commit is contained in:
ParthSareen 2025-05-05 17:25:35 -07:00
parent 7f2f996cd6
commit 516a540df7
7 changed files with 177 additions and 202 deletions

1
go.mod
View File

@ -35,6 +35,7 @@ require (
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect

2
go.sum
View File

@ -69,6 +69,8 @@ github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3
github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY=
github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 h1:+VexzzkMLb1tnvpuQdGT/DicIRW7MN8ozsXqBMgp0Hk=
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M=
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=

View File

@ -210,7 +210,16 @@ func nodeContainsToolCalls(n *parse.IfNode) bool {
return false
}
func ToolToken(tmpl *gotmpl.Template) (string, bool) {
// ToolPrefix2 returns the text that extractToolCallsTemplate reports for tmpl,
// with leading and trailing whitespace removed. The boolean result is false,
// and the string empty, when extractToolCallsTemplate reports no match.
func ToolPrefix2(tmpl *gotmpl.Template) (string, bool) {
	raw, found := extractToolCallsTemplate(tmpl)
	if found {
		return strings.TrimSpace(raw), true
	}
	return "", false
}
func ToolPrefix(tmpl *gotmpl.Template) (string, bool) {
tokenText, ok := extractToolCallsTemplate(tmpl)
if !ok {
return "", false

View File

@ -80,7 +80,7 @@ func TestToolToken(t *testing.T) {
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
got, ok := ToolToken(tmpl)
got, ok := ToolPrefix(tmpl)
if got != tt.want {
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
}

View File

@ -21,7 +21,6 @@ import (
"slices"
"strings"
"syscall"
gotmpl "text/template"
"time"
"github.com/gin-contrib/cors"
@ -1486,26 +1485,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
var sb strings.Builder
// var sb strings.Builder
var toolCallIndex int = 0
var templateToolToken string
var tmpl *gotmpl.Template
var tp *ToolParser
if len(req.Tools) > 0 {
var ok bool
templateToolToken, ok = ToolToken(m.Template.Template)
if !ok {
slog.Debug("no tool token found")
}
tmpl, ok = ToolTemplate(m)
if !ok {
slog.Debug("no tool template found")
}
tp = NewToolParser(m)
}
checkToolCall := false
if len(req.Tools) > 0 {
checkToolCall = true
}
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
@ -1526,50 +1512,29 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if r.Done {
if sb.Len() > 0 {
res.Message.Content = sb.String()
}
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
sb.WriteString(r.Content)
if len(req.Tools) > 0 && checkToolCall {
slog.Debug("parse tool calls", "content", sb.String(), "templateToolToken", templateToolToken)
toolCalls, partial, err := ParseToolCalls(sb.String(), templateToolToken, tmpl)
if err == nil {
if partial {
// circuit break to remove tool end token
if len(toolCalls) > 0 {
sb.Reset()
}
// If the tool call is partial, we need to wait for the next chunk
return
}
if len(req.Tools) > 0 && !tp.done {
fmt.Println("checking tool calls")
toolCalls, ok := tp.ParseToolCalls(r.Content)
if tp.state == PartialTool {
fmt.Println("partial tool, returning")
return
}
if ok && len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
// Remove content when tool call is present
res.Message.Content = ""
ch <- res
// Only way to have multiple calls is to have [] which is derived or provided
// This case occurs when the tool call is a json block - do not allow tool calls again
if templateToolToken == "" || (templateToolToken != "" && !strings.HasPrefix(sb.String(), templateToolToken)) {
checkToolCall = false
}
sb.Reset()
return
}
}
// If there is no template tool token, we don't need to check for tool calls after the first chunk
if templateToolToken == "" {
checkToolCall = false
}
res.Message.Content = sb.String()
sb.Reset()
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}

View File

@ -2,50 +2,39 @@ package server
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"strings"
gotmpl "text/template"
jsonv2 "github.com/go-json-experiment/json"
"github.com/ollama/ollama/api"
)
func parseObjects(s string) []map[string]any {
var objs []map[string]any
for offset := 0; offset < len(s); {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
err := decoder.Decode(&obj)
switch {
case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF):
return objs
case err != nil:
var syntax *json.SyntaxError
var unmarshalType *json.UnmarshalTypeError
switch {
case errors.As(err, &syntax):
offset += int(syntax.Offset)
continue
case errors.As(err, &unmarshalType):
offset += int(unmarshalType.Offset)
continue
default:
return nil
}
}
offset += int(decoder.InputOffset())
objs = append(objs, obj)
}
return objs
// State records what a ToolParser concluded about its buffered content the
// last time ParseToolCalls updated it.
type State int
const (
// NoTool is the zero value. ParseToolCalls sets it when parsing failed, the
// JSON was not reported as incomplete, and no tool prefix was stripped.
NoTool State = iota
// PartialTool is set when the tool prefix was seen (or the tail of the
// prefix ended the buffer) or when the JSON was reported as incomplete.
PartialTool
// ToolCall is set when at least one tool call was parsed from the buffer.
ToolCall
)
// ToolParser extracts tool calls from model output that arrives in pieces.
// Each piece handed to ParseToolCalls is appended to an internal buffer, and
// the whole buffer is re-parsed on every call.
type ToolParser struct {
// tmpl is executed with a sample tool call by parseJSONToolCalls to find
// which JSON keys hold the function name and its arguments.
tmpl *gotmpl.Template
// state is the outcome of the most recent ParseToolCalls call.
state State
// sb accumulates the content passed to ParseToolCalls; it is reset when
// the state becomes NoTool or ToolCall.
sb *strings.Builder
// toolPrefix is the value NewToolParser obtained from ToolPrefix; when
// non-empty and present at the start of the buffer it is stripped before
// JSON parsing.
toolPrefix string
// done is set by ParseToolCalls after a successful parse when toolPrefix is
// empty; ChatHandler stops calling the parser once it is true.
done bool
}
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// Returns parsed tool calls and a boolean indicating if the JSON is incomplete
func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool) {
func (p *ToolParser) parseJSONToolCalls(s string) ([]api.ToolCall, bool, bool) {
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
if err := p.tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
@ -57,35 +46,18 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool)
},
},
}); err != nil {
return nil, false
return nil, false, false
}
templateObjects := parseObjects(b.String())
if len(templateObjects) == 0 {
return nil, false
// slog.Debug("template", "template", b.String())
// ! this can be either a map or an array
var temp any
err := jsonv2.Unmarshal(b.Bytes(), &temp)
if err != nil {
return nil, false, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range templateObjects[0] {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}
if name == "" || arguments == "" {
return nil, false
}
responseObjects := parseObjects(s)
if len(responseObjects) == 0 {
return nil, false
}
// collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
@ -103,16 +75,63 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool)
return all
}
var objs []map[string]any
for _, p := range responseObjects {
objs = append(objs, collect(p)...)
var templateObjects []map[string]any
switch t := temp.(type) {
case map[string]any:
templateObjects = []map[string]any{t}
case []map[string]any:
templateObjects = t
// ! fallback?
case []any:
templateObjects = collect(t)
}
if len(templateObjects) == 0 {
return nil, false, false
}
// fmt.Println("template objects", templateObjects)
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range templateObjects[0] {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}
if name == "" || arguments == "" {
return nil, false, false
}
var responseObjects any
err = jsonv2.Unmarshal([]byte(s), &responseObjects)
if err != nil {
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
fmt.Println("Detected partial or incomplete JSON.")
fmt.Println("state", p.state)
return nil, true, false
} else {
fmt.Printf("Other error: %v\n", err)
fmt.Println("exiting", p.state)
return nil, false, false
}
}
var objs []map[string]any
objs = append(objs, collect(responseObjects)...)
if len(objs) == 0 {
return nil, false, false
}
slog.Debug("collected objects", "count", len(objs))
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
slog.Debug("found valid tool call", "name", n)
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
@ -122,54 +141,82 @@ func parseJSONToolCalls(tmpl *gotmpl.Template, s string) ([]api.ToolCall, bool)
}
}
return toolCalls, len(toolCalls) > 0
}
// routeToolParsing is a helper function that routes what kind of tool parsing to use
func routeToolParsing(s string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, bool) {
if strings.HasPrefix(s, "[{") || strings.HasPrefix(s, "```") || strings.HasPrefix(s, "{") {
if toolCalls, ok := parseJSONToolCalls(tmpl, s); ok {
return toolCalls, false, true
}
// in the case the JSON never finishes, the accumulated content should be sent downstream
return nil, true, true
}
// TODO(parthsareen): add python tool call support
return nil, false, false
slog.Debug("parsed tool calls", "count", len(toolCalls))
return toolCalls, len(toolCalls) > 0, true
}
// ParseToolCalls extracts tool calls from a string using a tool token prefix or direct JSON parsing.
// Returns tool calls, whether parsing is incomplete, and any errors.
func ParseToolCalls(s string, toolToken string, tmpl *gotmpl.Template) ([]api.ToolCall, bool, error) {
if tmpl == nil {
return nil, false, fmt.Errorf("no template provided")
}
func (p *ToolParser) ParseToolCalls(s string) ([]api.ToolCall, bool) {
p.sb.WriteString(s)
s = p.sb.String()
s = strings.TrimSpace(s)
slog.Debug("parse tool calls", "content", s)
if len(s) == 0 {
return nil, false, fmt.Errorf("empty input string")
return nil, false
}
if toolToken != "" {
if strings.HasPrefix(s, toolToken) {
s = strings.TrimSpace(s[len(toolToken):])
tc, _, ok := routeToolParsing(s, tmpl)
if len(tc) == 0 || !ok {
return nil, true, nil
}
return tc, false, nil
hasPrefix := false
if p.toolPrefix != "" {
if strings.HasPrefix(s, p.toolPrefix) {
s = strings.TrimSpace(s[len(p.toolPrefix):])
slog.Debug("tool prefix", "prefix", p.toolPrefix, "content", s)
p.state = PartialTool
hasPrefix = true
// Special token end case
} else if strings.HasSuffix(s, toolToken[2:]) {
tc := api.ToolCall{
Function: api.ToolCallFunction{
Name: toolToken,
},
}
return []api.ToolCall{tc}, true, nil
} else if strings.HasSuffix(s, p.toolPrefix[2:]) {
p.state = PartialTool
p.sb.Reset()
slog.Debug("setting to no tool", "content", s)
return nil, false
}
}
tcs, partial, ok := p.parseJSONToolCalls(s)
tc, partial, ok := routeToolParsing(s, tmpl)
if !ok {
return nil, false, fmt.Errorf("failed to parse tool calls for input: %q", s)
// TODO: figure out how to return the remaining string if not partial anymore
// update state
switch {
case !ok && !partial && hasPrefix:
p.state = PartialTool
case !ok && !partial:
p.state = NoTool
case !ok && partial:
p.state = PartialTool
case len(tcs) > 0:
p.state = ToolCall
}
if p.state == NoTool || p.state == ToolCall {
slog.Debug("resetting string builder", "state", p.state)
p.sb.Reset()
}
if !ok {
return nil, false
}
slog.Debug("returning tool calls", "tool calls", tcs)
fmt.Println("end state", p.state)
if p.toolPrefix == "" {
p.done = true
}
fmt.Println("len tcs", len(tcs))
return tcs, true
}
func NewToolParser(model *Model) *ToolParser {
templateToolPrefix, _ := ToolPrefix(model.Template.Template)
slog.Debug("tool prefix", "prefix", templateToolPrefix)
tmpl, ok := ToolTemplate(model)
if !ok {
return nil
}
return &ToolParser{
tmpl: tmpl,
sb: &strings.Builder{},
toolPrefix: templateToolPrefix,
done: false,
}
return tc, partial, nil
}

View File

@ -149,9 +149,10 @@ func TestParseToolCalls(t *testing.T) {
wantErr: false,
},
{
// TODO: fix the spacing issue
name: "qwen with single tool call",
model: "qwen2.5-coder",
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
token: "<tool_call>",
expected: []api.ToolCall{t1},
wantErr: false,
@ -185,7 +186,7 @@ func TestParseToolCalls(t *testing.T) {
}
for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
@ -204,25 +205,17 @@ func TestParseToolCalls(t *testing.T) {
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
tmpl, ok := ToolTemplate(m)
if !ok {
t.Fatal("no tool template found")
}
tp := NewToolParser(m)
got := []api.ToolCall{}
tokens := strings.Fields(tt.output)
sb := strings.Builder{}
success := false
tokens := strings.Fields(tt.output)
for _, tok := range tokens {
sb.WriteString(" " + tok)
toolCalls, partial, err := ParseToolCalls(sb.String(), tt.token, tmpl)
if err == nil {
s := " " + tok
toolCalls, ok := tp.ParseToolCalls(s)
if ok {
success = true
}
if partial {
continue
}
got = append(got, toolCalls...)
sb.Reset()
}
if !tt.wantErr {
@ -237,45 +230,3 @@ func TestParseToolCalls(t *testing.T) {
})
}
}
// TestParseObjects feeds parseObjects a JSON array, objects wrapped in
// non-JSON tokens, and a truncated object, and compares the decoded objects
// against the expected maps.
func TestParseObjects(t *testing.T) {
	// weatherCall builds the decoded form of one get_current_weather object.
	weatherCall := func(format, location string) map[string]any {
		return map[string]any{
			"name": "get_current_weather",
			"arguments": map[string]any{
				"format":   format,
				"location": location,
			},
		}
	}

	type testCase struct {
		input string
		want  []map[string]any
	}

	cases := []testCase{
		{
			input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
			want: []map[string]any{
				weatherCall("fahrenheit", "San Francisco, CA"),
				weatherCall("celsius", "Toronto, Canada"),
			},
		},
		{
			input: `<some_token>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
			want: []map[string]any{
				weatherCall("fahrenheit", "San Francisco, CA"),
			},
		},
		{
			input: `<some_token>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
			want: []map[string]any{
				weatherCall("fahrenheit", "San Francisco, CA"),
				weatherCall("celsius", "Toronto, ON"),
			},
		},
		{
			// A truncated object yields no results.
			input: `{"name": "get_current_weather", "arguments": `,
			want:  nil,
		},
	}

	for _, tt := range cases {
		t.Run(tt.input, func(t *testing.T) {
			got := parseObjects(tt.input)
			diff := cmp.Diff(got, tt.want)
			if diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
			}
		})
	}
}