From f5872a097cd5f1c7c38284cdb200b0ffbe0201f3 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Wed, 23 Apr 2025 15:45:35 -0700 Subject: [PATCH] checkpoint --- server/model.go | 424 ++++++++++++++++++++++++++++++++++++++--------- server/routes.go | 111 ++++++++++--- 2 files changed, 434 insertions(+), 101 deletions(-) diff --git a/server/model.go b/server/model.go index 4aac1e43f..d37a4a553 100644 --- a/server/model.go +++ b/server/model.go @@ -197,6 +197,8 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { return nil, false } + slog.Debug("parseToolCalls: template objects", "objects", templateObjects) + // find the keys that correspond to the name and arguments fields var name, arguments string for k, v := range templateObjects[0] { @@ -257,63 +259,196 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { return toolCalls, len(toolCalls) > 0 } -func (m *Model) ParseToolCallsNew(s string) ([]api.ToolCall, bool) { - // Parse both Python function calls and JSON function calls into ToolCall structs - // Example inputs: - // Python: func(a=2, b=2) - // JSON: {"function": {"name": "func", "arguments": {"a": 2, "b": 2}}} - // JSON array: [{"name": "func", "arguments": {"a": 2}}] +// ToolCallFormat represents different possible formats for tool calls +type toolCallFormat struct { + // Direct format + Name string `json:"name,omitempty"` + Arguments map[string]any `json:"arguments,omitempty"` - slog.Debug("parsing function calls", "input", s) + // Command-r-plus format + ToolName string `json:"tool_name,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` - // Try JSON parsing first - if strings.HasPrefix(strings.TrimSpace(s), "[") { - // Try parsing as JSON array - var jsonArray []map[string]any - if err := json.Unmarshal([]byte(s), &jsonArray); err == nil { - var toolCalls []api.ToolCall - for _, obj := range jsonArray { - if calls, ok := parseJSONToolCalls(obj); ok { - toolCalls = append(toolCalls, calls...) - } + // Function format + Function *struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` + } `json:"function,omitempty"` + + // Xlam format + ToolCalls []toolCallFormat `json:"tool_calls,omitempty"` +} + +func parseJSONToolCalls(obj map[string]any) ([]api.ToolCall, bool) { + // Helper to convert any to []any safely + toArray := func(v any) []any { + if arr, ok := v.([]any); ok { + return arr + } + return nil + } + + // Convert a single format to a tool call + makeToolCall := func(f toolCallFormat) (api.ToolCall, bool) { + switch { + case f.Name != "" && f.Arguments != nil: + return api.ToolCall{ + Function: api.ToolCallFunction{ + Name: f.Name, + Arguments: f.Arguments, + }, + }, true + case f.Name != "" && f.Parameters != nil: // Handle parameters field + return api.ToolCall{ + Function: api.ToolCallFunction{ + Name: f.Name, + Arguments: f.Parameters, + }, + }, true + case f.ToolName != "" && f.Parameters != nil: + return api.ToolCall{ + Function: api.ToolCallFunction{ + Name: f.ToolName, + Arguments: f.Parameters, + }, + }, true + case f.Function != nil && f.Function.Name != "": + args := f.Function.Arguments + if args == nil { + args = f.Function.Parameters } - if len(toolCalls) > 0 { - return toolCalls, true + if args != nil { + return api.ToolCall{ + Function: api.ToolCallFunction{ + Name: f.Function.Name, + Arguments: args, + }, + }, true } } - } else { - // Try parsing as single JSON object - var jsonObj map[string]any - if err := json.Unmarshal([]byte(s), &jsonObj); err == nil { - if toolCalls, ok := parseJSONToolCalls(jsonObj); ok { - return toolCalls, true + return api.ToolCall{}, false + } + + // Try parsing as array first + if arr := toArray(obj); arr != nil { + var calls []api.ToolCall + for _, item := range arr { + if itemMap, ok := item.(map[string]any); ok { + var format toolCallFormat + data, _ := json.Marshal(itemMap) + if err := json.Unmarshal(data, &format); err == nil { + if call, ok := makeToolCall(format); ok { + calls = append(calls, call) + } + } } } + if len(calls) > 0 { + return calls, true + } } - // Fall back to Python-style parsing - re := regexp.MustCompile(`(\w+)\((.*?)\)`) - matches := re.FindAllStringSubmatch(s, -1) - - if len(matches) == 0 { - slog.Debug("no function calls found") + // Try parsing as single object + var format toolCallFormat + data, _ := json.Marshal(obj) + if err := json.Unmarshal(data, &format); err != nil { return nil, false } - slog.Debug("found function calls", "matches", len(matches)) + // Handle xlam format (tool_calls array) + if len(format.ToolCalls) > 0 { + var calls []api.ToolCall + for _, f := range format.ToolCalls { + if call, ok := makeToolCall(f); ok { + calls = append(calls, call) + } + } + if len(calls) > 0 { + return calls, true + } + } - var toolCalls []api.ToolCall - for i, match := range matches { - name := match[1] - args := match[2] + // Try as single tool call + if call, ok := makeToolCall(format); ok { + return []api.ToolCall{call}, true + } - slog.Debug("parsing function call", "index", i, "name", name, "args", args) + return nil, false +} + +func (m *Model) GetToolCallFormat(s string) (string, string, bool) { + // Try to detect the tool call format from the model's template + tmpl := m.Template.Subtree(func(n parse.Node) bool { + if t, ok := n.(*parse.RangeNode); ok { + return slices.Contains(template.Identifiers(t.Pipe), "Content") + } + return false + }) + + if tmpl != nil { + // Execute template with test data to see the format + var b bytes.Buffer + if err := tmpl.Execute(&b, map[string][]api.ToolCall{ + "ToolCalls": { + { + Function: api.ToolCallFunction{ + Name: "function_name", + Arguments: api.ToolCallFunctionArguments{ + "argument1": "value1", + // "argument2": "value2", + }, + }, + }, + }, + }); err == nil { + // Look for special tokens in the template output + output := strings.TrimSpace(b.String()) + slog.Debug("tool call template output", "output", output) + if strings.Contains(output, "<") { + // Extract the special token between < and > + start := strings.Index(output, "<") + end := strings.Index(output, ">") + if start >= 0 && end > start { + token := output[start : end+1] + return output, token, true + } + } else if strings.Contains(output, "[") { + // Check if it's a tool call token rather than JSON array + start := strings.Index(output, "[") + end := strings.Index(output, "]") + if start >= 0 && end > start { + token := output[start : end+1] + // Only consider it a token if it's not valid JSON + var jsonTest any + if err := json.Unmarshal([]byte(token), &jsonTest); err != nil { + return output, token, true + } + } + } + } + } + return "", "", false +} + +func parsePythonFunctionCall(s string) (api.ToolCall, bool) { + re := regexp.MustCompile(`(\w+)\((.*?)\)`) + if match := re.FindStringSubmatchIndex(s); match != nil { + name := s[match[2]:match[3]] + args := s[match[4]:match[5]] + + // Check if there's a < after the closing bracket + if idx := strings.Index(s[match[5]:], "<"); idx >= 0 { + // Wait for closing > by returning false + if !strings.Contains(s[match[5]+idx:], ">") { + return api.ToolCall{}, false + } + } arguments := make(api.ToolCallFunctionArguments) - if strings.Contains(args, "=") { // Keyword args - pairs := strings.Split(args, ",") - for _, pair := range pairs { + pairs := strings.SplitSeq(args, ",") + for pair := range pairs { pair = strings.TrimSpace(pair) kv := strings.Split(pair, "=") if len(kv) == 2 { @@ -322,48 +457,179 @@ func (m *Model) ParseToolCallsNew(s string) ([]api.ToolCall, bool) { arguments[key] = value } } - } else { // Positional args - arguments["args"] = args - } - - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: name, - Arguments: arguments, - }, - }) - } - - slog.Debug("finished parsing", "tool_calls", len(toolCalls)) - return toolCalls, len(toolCalls) > 0 -} - -func parseJSONToolCalls(obj map[string]any) ([]api.ToolCall, bool) { - // Check for function-style format first - if function, ok := obj["function"].(map[string]any); ok { - name, _ := function["name"].(string) - args, _ := function["arguments"].(map[string]any) - if name != "" && args != nil { - return []api.ToolCall{{ + return api.ToolCall{ Function: api.ToolCallFunction{ Name: name, - Arguments: args, + Arguments: arguments, }, - }}, true + }, true } } - - // Check for direct name/parameters format - if name, ok := obj["name"].(string); ok { - if params, ok := obj["parameters"].(map[string]any); ok { - return []api.ToolCall{{ - Function: api.ToolCallFunction{ - Name: name, - Arguments: params, - }, - }}, true - } - } - - return nil, false + return api.ToolCall{}, false +} + +func (m *Model) ParseToolCallsStream(s string, prefix *string, specialToken *string) ([]api.ToolCall, bool, bool) { + // The prefix check for for the tags shouldn't really be used and we should be consuming this from the model + // Knowing what the tool token enables quicker and more reliable parsing + // TODO: not sure how we're going to handle chatting before the tool call + // TODO: detection would be relying on the model to know what the tool token is + // fmt.Println("parsing tool calls", s) + + if prefix == nil { + prefix = new(string) + *prefix = "" + } + if specialToken == nil { + specialToken = new(string) + *specialToken = "" + } + // TODO: cache this + // _, token, ok := m.GetToolCallFormat(s) + // if ok && token != "" { + // fmt.Println("token", token) + // *specialToken = token + // } + // fmt.Println("prefix", *prefix) + // fmt.Println("special token", *specialToken) + var partial bool + + s = strings.TrimSpace(s) + if len(s) == 0 { + return nil, false, false + } + + if specialToken != nil && len(*specialToken) > 0 { + s2 := *specialToken + if strings.HasPrefix(s, string(s2[0])) { + // fmt.Println("prefix 1 is", string(s2[0])) + partial = true + *prefix = string(s2[0]) + } + } + + if len(s) > 0 { + if s[0] == '[' { + s = strings.ReplaceAll(s, "\n", "") + // tool call list with no special token + if len(s) > 1 && s[1] == '{' { + // fmt.Println("prefix 2 in [{", string(s[0])) + partial = true + *specialToken = "[{" + *prefix = "[{" + } else if *specialToken == "" { + // possible tool call with special token but not in template + // split s over spaces to check for special token + if len(s) > 0 && s[len(s)-1] == ']' { + partial = true + *specialToken = s + *prefix = "[" + } + } + } else if s[0] == '{' { + // fmt.Println("prefix 2 in {", string(s[0])) + partial = true + *specialToken = "{" + *prefix = "{" + } else if s[0] == '<' { + // TODO: the only issue here is that we might miss a > if the token is weird + // The 1 && s[1] == '/' { + // fmt.Println("prefix3 in <", string(s[0])) + // returning a partial here is a hack to ensure that we don't send the content downstream + return nil, true, true + // TODO: jank hack to get special token right + // special token might not be set yet + } else if s[len(s)-1] == '>' { + partial = true + *specialToken = s + *prefix = "<" + } else if specialToken != nil && *specialToken == "" { + partial = true + *specialToken = "<" + *prefix = "<" + } + } + } + + // fmt.Println("special token", *specialToken) + // fmt.Println("prefix", *prefix) + + if !partial { + return nil, false, false + } + // Look for tags + // fmt.Println("looking for special token", *specialToken) + start := strings.Index(s, *specialToken) + if start == -1 { + if partial { + // fmt.Println("did not find opening tag, partial match", *specialToken) + return nil, true, true + } + return nil, false, false + } + end := len(s) + + // Extract content between tags + var content string + // fmt.Println("prefix before is", *prefix) + if *prefix == "[{" || *prefix == "{" { + content = s[start:end] + } else { + content = s[start+len(*specialToken) : end] + } + content = strings.TrimSpace(content) + // fmt.Println("content", content) + + var toolCalls []api.ToolCall + + // Try parsing as JSON first - could be single object or array + var jsonObj any + if err := json.Unmarshal([]byte(content), &jsonObj); err == nil { + // Try as single object + if obj, ok := jsonObj.(map[string]any); ok { + // fmt.Println("obj", obj) + if calls, ok := parseJSONToolCalls(obj); ok { + toolCalls = append(toolCalls, calls...) + } + } + // Try as array of objects + if arr, ok := jsonObj.([]any); ok { + for _, item := range arr { + if obj, ok := item.(map[string]any); ok { + if calls, ok := parseJSONToolCalls(obj); ok { + toolCalls = append(toolCalls, calls...) + } + } + } + } + } else { + // TODO: review this case + // Check for partial JSON before trying Python style + if strings.HasPrefix(content, "{") || strings.HasPrefix(content, "[{") { + // We have an opening brace/bracket but failed to parse - likely partial JSON + return nil, true, true + } + + // Try parsing as Python function call + if toolCall, ok := parsePythonFunctionCall(content); ok { + toolCalls = append(toolCalls, toolCall) + } + } + + // Only return success if we found valid tool calls and no errors + if len(toolCalls) > 0 { + // Check if any of the tool calls are malformed + for _, call := range toolCalls { + if call.Function.Name == "" || len(call.Function.Arguments) == 0 { + return nil, false, false + } + } + return toolCalls, false, true + } + + // fmt.Println("no tool calls found, partial match", partial) + if partial { + return nil, true, true + } + return nil, false, false } diff --git a/server/routes.go b/server/routes.go index 906426b18..554dc4daf 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1526,6 +1526,16 @@ func (s *Server) ChatHandler(c *gin.Context) { defer close(ch) var sb strings.Builder var toolCallIndex int = 0 + var sentWithTools int = 0 + var prefix string + // var specialToken string + _, specialToken, _ := m.GetToolCallFormat(sb.String()) + + var minDuration time.Duration = math.MaxInt64 + var maxDuration time.Duration + var totalDuration time.Duration + var checkCount int + const MAX_TOOL_TOKENS = 6 if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1546,6 +1556,14 @@ func (s *Server) ChatHandler(c *gin.Context) { } if r.Done { + slog.Debug("min duration", "duration", minDuration) + slog.Debug("max duration", "duration", maxDuration) + slog.Debug("total duration", "duration", totalDuration) + slog.Debug("check count", "count", checkCount) + // slog.Debug("average duration", "duration", totalDuration/time.Duration(checkCount)) + // if sb.Len() > 0 { + // res.Message.Content = sb.String() + // } res.DoneReason = r.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -1563,25 +1581,46 @@ func (s *Server) ChatHandler(c *gin.Context) { // If tools are recognized, use a flag to track the sending of a tool downstream // This ensures that content is cleared from the message on the last chunk sent sb.WriteString(r.Content) - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - res.Message.ToolCalls = toolCalls - for i := range toolCalls { - toolCalls[i].Function.Index = toolCallIndex - toolCallIndex++ + // TODO: here we want to prefix check the tool ideally or derive the tool token from the model + // TODO: if we are deriving the tool token, then a heuristic must be applied to stream eventually + // TODO: if the prefix check fails, send the content downstream and reset the builder + startTime := time.Now() + if len(req.Tools) > 0 && sentWithTools < MAX_TOOL_TOKENS { + toolCalls, partial, ok := m.ParseToolCallsStream(sb.String(), &prefix, &specialToken) + duration := time.Since(startTime) + checkCount++ + minDuration = min(minDuration, duration) + maxDuration = max(maxDuration, duration) + totalDuration += duration + slog.Debug("tool call duration", "duration", duration) + if ok { + // fmt.Println("toolCalls", toolCalls, partial, ok, duration) + if partial { + // If the tool call is partial, we need to wait for the next chunk + return + } + res.Message.ToolCalls = toolCalls + for i := range toolCalls { + toolCalls[i].Function.Index = toolCallIndex + toolCallIndex++ + } + sentWithTools = 0 + prefix = "" + specialToken = "" + res.Message.Content = "" + sb.Reset() + ch <- res + return } - res.Message.Content = "" - sb.Reset() - ch <- res - return } - if r.Done { - // Send any remaining content if no tool calls were detected - if toolCallIndex == 0 { - res.Message.Content = sb.String() - } - ch <- res - } + // Send any remaining content if no tool calls were detected + // if toolCallIndex == 0 { + // fmt.Println("toolCallIndex", toolCallIndex) + sentWithTools++ + res.Message.Content = sb.String() + sb.Reset() + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} } @@ -1590,11 +1629,35 @@ func (s *Server) ChatHandler(c *gin.Context) { if req.Stream != nil && !*req.Stream { var resp api.ChatResponse var sb strings.Builder + var toolCalls []api.ToolCall + var prefix string + var specialToken string + const MAX_TOOL_TOKENS = 6 + sentWithTools := 0 + var tb strings.Builder for rr := range ch { switch t := rr.(type) { case api.ChatResponse: sb.WriteString(t.Message.Content) resp = t + if len(req.Tools) > 0 && sentWithTools < MAX_TOOL_TOKENS { + tb.WriteString(t.Message.Content) + if tcs, partial, ok := m.ParseToolCallsStream(tb.String(), &prefix, &specialToken); ok { + if !partial { + // resp.Message.ToolCalls = toolCalls + toolCalls = append(toolCalls, tcs...) + resp.Message.Content = "" + tb.Reset() + prefix = "" + specialToken = "" + } + } else { + // equivalent to no partial - send the content downstream + tb.Reset() + sentWithTools++ + + } + } case gin.H: msg, ok := t["error"].(string) if !ok { @@ -1610,14 +1673,18 @@ func (s *Server) ChatHandler(c *gin.Context) { } resp.Message.Content = sb.String() - - if len(req.Tools) > 0 { - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { - resp.Message.ToolCalls = toolCalls - resp.Message.Content = "" - } + if len(toolCalls) > 0 { + resp.Message.ToolCalls = toolCalls + // resp.Message.Content = "" } + // if len(req.Tools) > 0 { + // if toolCalls, ok := m.ParseToolCalls(sb.String()); ok { + // resp.Message.ToolCalls = toolCalls + // resp.Message.Content = "" + // } + // } + c.JSON(http.StatusOK, resp) return }