Compare commits
	
		
			4 Commits
		
	
	
		
			main
			...
			parth/pyth
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | b4cd1118ab | ||
|   | 128c90d3ac | ||
|   | f5872a097c | ||
|   | 3ac5e0f102 | 
							
								
								
									
										386
									
								
								server/model.go
									
									
									
									
									
								
							
							
						
						
									
										386
									
								
								server/model.go
									
									
									
									
									
								
							| @@ -10,6 +10,7 @@ import ( | ||||
| 	"log/slog" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"regexp" | ||||
| 	"slices" | ||||
| 	"strings" | ||||
| 	"text/template/parse" | ||||
| @@ -153,99 +154,342 @@ func parseObjects(s string) []map[string]any { | ||||
| 	return objs | ||||
| } | ||||
|  | ||||
| // parseToolCalls attempts to parse a JSON string into a slice of ToolCalls. | ||||
| // mxyng: this only really works if the input contains tool calls in some JSON format | ||||
| func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { | ||||
| 	// create a subtree from the node that ranges over .ToolCalls | ||||
| // Get tool call token from model template | ||||
| func (m *Model) TemplateToolToken() (string, string, bool) { | ||||
| 	// Try to detect the tool call format from the model's template | ||||
| 	tmpl := m.Template.Subtree(func(n parse.Node) bool { | ||||
| 		if t, ok := n.(*parse.RangeNode); ok { | ||||
| 			return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") | ||||
| 		} | ||||
|  | ||||
| 		return false | ||||
| 	}) | ||||
|  | ||||
| 	if tmpl == nil { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	var b bytes.Buffer | ||||
| 	if err := tmpl.Execute(&b, map[string][]api.ToolCall{ | ||||
| 		"ToolCalls": { | ||||
| 			{ | ||||
| 				Function: api.ToolCallFunction{ | ||||
| 					Name: "@@name@@", | ||||
| 					Arguments: api.ToolCallFunctionArguments{ | ||||
| 						"@@argument@@": 1, | ||||
| 	// fmt.Println("tool call template", tmpl) | ||||
| 	if tmpl != nil { | ||||
| 		// Execute template with test data to see the format | ||||
| 		var b bytes.Buffer | ||||
| 		if err := tmpl.Execute(&b, map[string][]api.ToolCall{ | ||||
| 			"ToolCalls": { | ||||
| 				{ | ||||
| 					Function: api.ToolCallFunction{ | ||||
| 						Name: "function_name", | ||||
| 						Arguments: api.ToolCallFunctionArguments{ | ||||
| 							"argument1": "value1", | ||||
| 							// "argument2": "value2", | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	}); err != nil { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	templateObjects := parseObjects(b.String()) | ||||
| 	if len(templateObjects) == 0 { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	// find the keys that correspond to the name and arguments fields | ||||
| 	var name, arguments string | ||||
| 	for k, v := range templateObjects[0] { | ||||
| 		switch v.(type) { | ||||
| 		case string: | ||||
| 			name = k | ||||
| 		case map[string]any: | ||||
| 			arguments = k | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if name == "" || arguments == "" { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	responseObjects := parseObjects(s) | ||||
| 	if len(responseObjects) == 0 { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	// collect all nested objects | ||||
| 	var collect func(any) []map[string]any | ||||
| 	collect = func(obj any) (all []map[string]any) { | ||||
| 		switch o := obj.(type) { | ||||
| 		case map[string]any: | ||||
| 			all = append(all, o) | ||||
| 			for _, v := range o { | ||||
| 				all = append(all, collect(v)...) | ||||
| 			} | ||||
| 		case []any: | ||||
| 			for _, v := range o { | ||||
| 				all = append(all, collect(v)...) | ||||
| 		}); err == nil { | ||||
| 			// Look for special tokens in the template output | ||||
| 			output := strings.TrimSpace(b.String()) | ||||
| 			slog.Debug("tool call template output", "output", output) | ||||
| 			if strings.Contains(output, "<") { | ||||
| 				// Extract the special token between < and > | ||||
| 				start := strings.Index(output, "<") | ||||
| 				end := strings.Index(output, ">") | ||||
| 				if start >= 0 && end > start { | ||||
| 					token := output[start : end+1] | ||||
| 					return output, token, true | ||||
| 				} | ||||
| 			} else if strings.Contains(output, "[") { | ||||
| 				// Check if it's a tool call token rather than JSON array | ||||
| 				start := strings.Index(output, "[") | ||||
| 				end := strings.Index(output, "]") | ||||
| 				if start >= 0 && end > start { | ||||
| 					token := output[start : end+1] | ||||
| 					// Only consider it a token if it's not valid JSON | ||||
| 					var jsonTest any | ||||
| 					if err := json.Unmarshal([]byte(token), &jsonTest); err != nil { | ||||
| 						return output, token, true | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return all | ||||
| 	} | ||||
| 	return "", "", false | ||||
| } | ||||
|  | ||||
| 	var objs []map[string]any | ||||
| 	for _, p := range responseObjects { | ||||
| 		objs = append(objs, collect(p)...) | ||||
| func parsePythonFunctionCall(s string) ([]api.ToolCall, bool) { | ||||
| 	re := regexp.MustCompile(`(\w+)\((.*?)\)`) | ||||
| 	matches := re.FindAllStringSubmatchIndex(s, -1) | ||||
| 	if len(matches) == 0 { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	var toolCalls []api.ToolCall | ||||
| 	for _, kv := range objs { | ||||
| 		n, nok := kv[name].(string) | ||||
| 		a, aok := kv[arguments].(map[string]any) | ||||
| 		if nok && aok { | ||||
| 	for _, match := range matches { | ||||
| 		name := s[match[2]:match[3]] | ||||
| 		args := s[match[4]:match[5]] | ||||
|  | ||||
| 		arguments := make(api.ToolCallFunctionArguments) | ||||
| 		if strings.Contains(args, "=") { // Keyword args | ||||
| 			pairs := strings.SplitSeq(args, ",") | ||||
| 			for pair := range pairs { | ||||
| 				pair = strings.TrimSpace(pair) | ||||
| 				kv := strings.Split(pair, "=") | ||||
| 				if len(kv) == 2 { | ||||
| 					key := strings.TrimSpace(kv[0]) | ||||
| 					value := strings.TrimSpace(kv[1]) | ||||
| 					arguments[key] = value | ||||
| 				} | ||||
| 			} | ||||
| 			toolCalls = append(toolCalls, api.ToolCall{ | ||||
| 				Function: api.ToolCallFunction{ | ||||
| 					Name:      n, | ||||
| 					Arguments: a, | ||||
| 					Name:      name, | ||||
| 					Arguments: arguments, | ||||
| 				}, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return toolCalls, len(toolCalls) > 0 | ||||
| 	if len(toolCalls) > 0 { | ||||
| 		return toolCalls, true | ||||
| 	} | ||||
| 	return nil, false | ||||
| } | ||||
|  | ||||
| // ToolCallFormat represents different possible formats for tool calls | ||||
| type toolCallFormat struct { | ||||
| 	// Direct format | ||||
| 	Name      string         `json:"name,omitempty"` | ||||
| 	Arguments map[string]any `json:"arguments,omitempty"` | ||||
|  | ||||
| 	// Command-r-plus format | ||||
| 	ToolName   string         `json:"tool_name,omitempty"` | ||||
| 	Parameters map[string]any `json:"parameters,omitempty"` | ||||
|  | ||||
| 	// Function format | ||||
| 	Function *struct { | ||||
| 		Name       string         `json:"name"` | ||||
| 		Arguments  map[string]any `json:"arguments,omitempty"` | ||||
| 		Parameters map[string]any `json:"parameters,omitempty"` | ||||
| 	} `json:"function,omitempty"` | ||||
|  | ||||
| 	// Xlam format | ||||
| 	ToolCalls []toolCallFormat `json:"tool_calls,omitempty"` | ||||
| } | ||||
|  | ||||
| func parseJSONToolCalls(obj map[string]any) ([]api.ToolCall, bool) { | ||||
| 	// Helper to convert any to []any safely | ||||
| 	toArray := func(v any) []any { | ||||
| 		if arr, ok := v.([]any); ok { | ||||
| 			return arr | ||||
| 		} | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// Convert a single format to a tool call | ||||
| 	makeToolCall := func(f toolCallFormat) (api.ToolCall, bool) { | ||||
| 		switch { | ||||
| 		case f.Name != "" && f.Arguments != nil: | ||||
| 			return api.ToolCall{ | ||||
| 				Function: api.ToolCallFunction{ | ||||
| 					Name:      f.Name, | ||||
| 					Arguments: f.Arguments, | ||||
| 				}, | ||||
| 			}, true | ||||
| 		case f.Name != "" && f.Parameters != nil: // Handle parameters field | ||||
| 			return api.ToolCall{ | ||||
| 				Function: api.ToolCallFunction{ | ||||
| 					Name:      f.Name, | ||||
| 					Arguments: f.Parameters, | ||||
| 				}, | ||||
| 			}, true | ||||
| 		case f.ToolName != "" && f.Parameters != nil: | ||||
| 			return api.ToolCall{ | ||||
| 				Function: api.ToolCallFunction{ | ||||
| 					Name:      f.ToolName, | ||||
| 					Arguments: f.Parameters, | ||||
| 				}, | ||||
| 			}, true | ||||
| 		case f.Function != nil && f.Function.Name != "": | ||||
| 			args := f.Function.Arguments | ||||
| 			if args == nil { | ||||
| 				args = f.Function.Parameters | ||||
| 			} | ||||
| 			if args != nil { | ||||
| 				return api.ToolCall{ | ||||
| 					Function: api.ToolCallFunction{ | ||||
| 						Name:      f.Function.Name, | ||||
| 						Arguments: args, | ||||
| 					}, | ||||
| 				}, true | ||||
| 			} | ||||
| 		} | ||||
| 		return api.ToolCall{}, false | ||||
| 	} | ||||
|  | ||||
| 	// Try parsing as array first | ||||
| 	if arr := toArray(obj); arr != nil { | ||||
| 		var calls []api.ToolCall | ||||
| 		for _, item := range arr { | ||||
| 			if itemMap, ok := item.(map[string]any); ok { | ||||
| 				var format toolCallFormat | ||||
| 				data, _ := json.Marshal(itemMap) | ||||
| 				if err := json.Unmarshal(data, &format); err == nil { | ||||
| 					if call, ok := makeToolCall(format); ok { | ||||
| 						calls = append(calls, call) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		if len(calls) > 0 { | ||||
| 			return calls, true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Try parsing as single object | ||||
| 	var format toolCallFormat | ||||
| 	data, _ := json.Marshal(obj) | ||||
| 	if err := json.Unmarshal(data, &format); err != nil { | ||||
| 		return nil, false | ||||
| 	} | ||||
|  | ||||
| 	// Handle xlam format (tool_calls array) | ||||
| 	if len(format.ToolCalls) > 0 { | ||||
| 		var calls []api.ToolCall | ||||
| 		for _, f := range format.ToolCalls { | ||||
| 			if call, ok := makeToolCall(f); ok { | ||||
| 				calls = append(calls, call) | ||||
| 			} | ||||
| 		} | ||||
| 		if len(calls) > 0 { | ||||
| 			return calls, true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Try as single tool call | ||||
| 	if call, ok := makeToolCall(format); ok { | ||||
| 		return []api.ToolCall{call}, true | ||||
| 	} | ||||
|  | ||||
| 	return nil, false | ||||
| } | ||||
|  | ||||
| // token, partial, success | ||||
| func deriveToolToken(s string, prefix string) (string, bool, bool) { | ||||
| 	// There shouldn't be spaces in a tool token | ||||
| 	if len(strings.Fields(s)) > 1 { | ||||
| 		return "", false, false | ||||
| 	} | ||||
|  | ||||
| 	if prefix == "[" && len(s) > 1 && s[len(s)-1] == ']' { | ||||
| 		return s, false, true | ||||
| 	} else if prefix == "<" && len(s) > 1 && s[len(s)-1] == '>' { | ||||
| 		return s, false, true | ||||
| 	} | ||||
| 	return "", true, true | ||||
| } | ||||
|  | ||||
| func parseJSON(s string) ([]api.ToolCall, bool) { | ||||
| 	objs := parseObjects(s) | ||||
| 	tcs := []api.ToolCall{} | ||||
| 	for _, obj := range objs { | ||||
| 		toolCalls, ok := parseJSONToolCalls(obj) | ||||
| 		if ok { | ||||
| 			tcs = append(tcs, toolCalls...) | ||||
| 		} | ||||
| 	} | ||||
| 	if len(tcs) > 0 { | ||||
| 		return tcs, true | ||||
| 	} | ||||
| 	return nil, false | ||||
| } | ||||
|  | ||||
| // returns tool calls, partial, success | ||||
| func (m *Model) ParseToolCalls(s string, toolToken *string) ([]api.ToolCall, bool, bool) { | ||||
| 	// [ case can either be JSON, Python or a Tool Token | ||||
| 	s = strings.TrimSpace(s) | ||||
| 	fmt.Printf("ParseToolCallsNew input: %q\n", s) | ||||
| 	if len(s) == 0 { | ||||
| 		return nil, false, false | ||||
| 	} | ||||
|  | ||||
| 	if strings.HasPrefix(s, "[") { | ||||
| 		fmt.Println("Found [ prefix") | ||||
| 		// JSON case | ||||
| 		// we do not consider array JSONs as tool calls | ||||
| 		if strings.HasPrefix(s, "[{") { | ||||
| 			fmt.Println("Found [{ prefix - attempting JSON parse") | ||||
| 			// TODO: mark as JSON partial | ||||
| 			if calls, ok := parseJSON(s); ok { | ||||
| 				fmt.Printf("Successfully parsed JSON, found %d calls\n", len(calls)) | ||||
| 				return calls, false, true | ||||
| 			} | ||||
| 			return nil, true, true | ||||
| 		} | ||||
| 		// Python Case | ||||
| 		// We just do a full python check here | ||||
| 		fmt.Println("Attempting Python function parse") | ||||
| 		tc, ok := parsePythonFunctionCall(s) | ||||
| 		if ok { | ||||
| 			fmt.Printf("Successfully parsed Python function: %+v\n", tc) | ||||
| 			return tc, false, true | ||||
| 		} | ||||
| 		// Tool Token Case - this is okay if it's a real tool token and we couldn't get from template | ||||
| 		fmt.Println("Attempting to derive tool token") | ||||
| 		if toolToken == nil || *toolToken == "" { | ||||
| 			toolTok, partial, ok := deriveToolToken(s, "[") | ||||
| 			if !ok { | ||||
| 				return nil, false, false | ||||
| 			} | ||||
| 			if partial { | ||||
| 				return nil, true, true | ||||
| 			} | ||||
| 			*toolToken = toolTok | ||||
| 		} | ||||
| 		fmt.Printf("Found tool token: %q\n", *toolToken) | ||||
| 		s = strings.TrimSpace(s[len(*toolToken):]) | ||||
| 		fmt.Printf("Recursing with remaining string: %q\n", s) | ||||
| 		if toolCalls, partial, ok := m.ParseToolCalls(s, toolToken); ok { | ||||
| 			return toolCalls, partial, true | ||||
| 		} | ||||
| 		return nil, true, true | ||||
| 	} else if strings.HasPrefix(s, "{") || strings.HasPrefix(s, "```") { | ||||
| 		// // TODO: temp fix | ||||
| 		// if strings.HasPrefix(s, "```") && len(s) == 3 { | ||||
| 		// 	return nil, false, false | ||||
| 		// } | ||||
| 		fmt.Println("Found { prefix - attempting JSON parse with ", s) | ||||
| 		if calls, ok := parseJSON(s); ok { | ||||
| 			fmt.Printf("Successfully parsed JSON object, found %d calls\n", len(calls)) | ||||
| 			return calls, false, true | ||||
| 		} | ||||
| 		fmt.Println("Failed to parse JSON in JSON case") | ||||
| 		// TODO: possible case where it never finishes parsing - then what? | ||||
| 		return nil, true, true | ||||
| 	} else if strings.HasPrefix(s, "<") { | ||||
| 		fmt.Println("Found < prefix - attempting to derive tool token") | ||||
| 		if toolToken == nil || *toolToken == "" { | ||||
| 			toolTok, partial, ok := deriveToolToken(s, "<") | ||||
| 			if !ok { | ||||
| 				return nil, false, false | ||||
| 			} | ||||
| 			if partial { | ||||
| 				return nil, true, true | ||||
| 			} | ||||
| 			*toolToken = toolTok | ||||
| 			fmt.Printf("Found tool token: %q\n", *toolToken) | ||||
| 		} | ||||
| 		fmt.Printf("Found tool token: %q\n", *toolToken) | ||||
| 		s = strings.TrimSpace(s[len(*toolToken):]) | ||||
| 		fmt.Printf("Recursing with remaining string: %q\n", s) | ||||
| 		if toolCalls, partial, ok := m.ParseToolCalls(s, toolToken); ok { | ||||
| 			return toolCalls, partial, true | ||||
| 		} | ||||
| 		return nil, true, true | ||||
| 	} else if strings.Contains(s, "(") || len(strings.Fields(s)) == 1 { | ||||
| 		fmt.Println("Attempting Python function parse") | ||||
| 		tc, ok := parsePythonFunctionCall(s) | ||||
| 		if ok { | ||||
| 			fmt.Printf("Successfully parsed Python function: %+v\n", tc) | ||||
| 			return tc, false, true | ||||
| 		} | ||||
| 		fmt.Printf("Failed to parse Python function: %q, returning partial", s) | ||||
| 		return nil, true, true | ||||
| 	} | ||||
| 	fmt.Println("No successful parse paths found") | ||||
| 	fmt.Printf("failed string: %q\n", s) | ||||
| 	return nil, false, false | ||||
| } | ||||
|   | ||||
							
								
								
									
										112
									
								
								server/routes.go
									
									
									
									
									
								
							
							
						
						
									
										112
									
								
								server/routes.go
									
									
									
									
									
								
							| @@ -1526,6 +1526,17 @@ func (s *Server) ChatHandler(c *gin.Context) { | ||||
| 		defer close(ch) | ||||
| 		var sb strings.Builder | ||||
| 		var toolCallIndex int = 0 | ||||
| 		var sentWithTools int = 0 | ||||
| 		// var prefix string | ||||
| 		// var templateToolToken string | ||||
| 		_, templateToolToken, _ := m.TemplateToolToken() | ||||
| 		// fmt.Println("special token", templateToolToken) | ||||
|  | ||||
| 		var minDuration time.Duration = math.MaxInt64 | ||||
| 		var maxDuration time.Duration | ||||
| 		var totalDuration time.Duration | ||||
| 		var checkCount int | ||||
| 		const maxToolTokens = 1 | ||||
| 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ | ||||
| 			Prompt:  prompt, | ||||
| 			Images:  images, | ||||
| @@ -1546,6 +1557,14 @@ func (s *Server) ChatHandler(c *gin.Context) { | ||||
| 			} | ||||
|  | ||||
| 			if r.Done { | ||||
| 				slog.Debug("min duration", "duration", minDuration) | ||||
| 				slog.Debug("max duration", "duration", maxDuration) | ||||
| 				slog.Debug("total duration", "duration", totalDuration) | ||||
| 				slog.Debug("check count", "count", checkCount) | ||||
| 				// slog.Debug("average duration", "duration", totalDuration/time.Duration(checkCount)) | ||||
| 				// if sb.Len() > 0 { | ||||
| 				// 	res.Message.Content = sb.String() | ||||
| 				// } | ||||
| 				res.DoneReason = r.DoneReason.String() | ||||
| 				res.TotalDuration = time.Since(checkpointStart) | ||||
| 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart) | ||||
| @@ -1563,25 +1582,48 @@ func (s *Server) ChatHandler(c *gin.Context) { | ||||
| 			// If tools are recognized, use a flag to track the sending of a tool downstream | ||||
| 			// This ensures that content is cleared from the message on the last chunk sent | ||||
| 			sb.WriteString(r.Content) | ||||
| 			if toolCalls, ok := m.parseToolCalls(sb.String()); ok { | ||||
| 				res.Message.ToolCalls = toolCalls | ||||
| 				for i := range toolCalls { | ||||
| 					toolCalls[i].Function.Index = toolCallIndex | ||||
| 					toolCallIndex++ | ||||
| 			startTime := time.Now() | ||||
| 			// TODO: work max tool tok logic | ||||
| 			if len(req.Tools) > 0 && sentWithTools < maxToolTokens { | ||||
| 				toolCalls, partial, ok := m.ParseToolCalls(sb.String(), &templateToolToken) | ||||
| 				duration := time.Since(startTime) | ||||
| 				checkCount++ | ||||
| 				minDuration = min(minDuration, duration) | ||||
| 				maxDuration = max(maxDuration, duration) | ||||
| 				totalDuration += duration | ||||
| 				slog.Debug("tool call duration", "duration", duration) | ||||
| 				if ok { | ||||
| 					// fmt.Println("toolCalls", toolCalls, partial, ok, duration) | ||||
| 					if partial { | ||||
| 						// If the tool call is partial, we need to wait for the next chunk | ||||
| 						return | ||||
| 					} | ||||
| 					slog.Debug("toolCalls", "toolCalls", toolCalls, "partial", partial, "ok", ok) | ||||
| 					res.Message.ToolCalls = toolCalls | ||||
| 					for i := range toolCalls { | ||||
| 						toolCalls[i].Function.Index = toolCallIndex | ||||
| 						toolCallIndex++ | ||||
| 					} | ||||
| 					sentWithTools = 0 | ||||
| 					// prefix = "" | ||||
| 					templateToolToken = "" | ||||
| 					res.Message.Content = "" | ||||
| 					sb.Reset() | ||||
| 					ch <- res | ||||
| 					// TODO: revisit this | ||||
| 					sentWithTools++ | ||||
| 					slog.Debug("fired on tool call", "toolCalls", toolCalls, "toolCallIndex", toolCallIndex) | ||||
| 					return | ||||
| 				} | ||||
| 				res.Message.Content = "" | ||||
| 				sb.Reset() | ||||
| 				ch <- res | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			if r.Done { | ||||
| 				// Send any remaining content if no tool calls were detected | ||||
| 				if toolCallIndex == 0 { | ||||
| 					res.Message.Content = sb.String() | ||||
| 				} | ||||
| 				ch <- res | ||||
| 			} | ||||
| 			// Send any remaining content if no tool calls were detected | ||||
| 			// if toolCallIndex == 0 { | ||||
| 			// fmt.Println("toolCallIndex", toolCallIndex) | ||||
| 			sentWithTools++ | ||||
| 			res.Message.Content = sb.String() | ||||
| 			sb.Reset() | ||||
| 			ch <- res | ||||
| 		}); err != nil { | ||||
| 			ch <- gin.H{"error": err.Error()} | ||||
| 		} | ||||
| @@ -1590,11 +1632,33 @@ func (s *Server) ChatHandler(c *gin.Context) { | ||||
| 	if req.Stream != nil && !*req.Stream { | ||||
| 		var resp api.ChatResponse | ||||
| 		var sb strings.Builder | ||||
| 		var toolCalls []api.ToolCall | ||||
| 		const MAX_TOOL_TOKENS = 1 | ||||
| 		sentWithTools := 0 | ||||
| 		var tb strings.Builder | ||||
| 		_, templateToolToken, _ := m.TemplateToolToken() | ||||
| 		for rr := range ch { | ||||
| 			switch t := rr.(type) { | ||||
| 			case api.ChatResponse: | ||||
| 				sb.WriteString(t.Message.Content) | ||||
| 				resp = t | ||||
| 				// TODO: work max tool tok logic | ||||
| 				if len(req.Tools) > 0 && sentWithTools < MAX_TOOL_TOKENS { | ||||
| 					tb.WriteString(t.Message.Content) | ||||
| 					if tcs, partial, ok := m.ParseToolCalls(tb.String(), &templateToolToken); ok { | ||||
| 						if !partial { | ||||
| 							// resp.Message.ToolCalls = toolCalls | ||||
| 							toolCalls = append(toolCalls, tcs...) | ||||
| 							resp.Message.Content = "" | ||||
| 							tb.Reset() | ||||
| 						} | ||||
| 					} else { | ||||
| 						// equivalent to no partial - send the content downstream | ||||
| 						tb.Reset() | ||||
| 						sentWithTools++ | ||||
|  | ||||
| 					} | ||||
| 				} | ||||
| 			case gin.H: | ||||
| 				msg, ok := t["error"].(string) | ||||
| 				if !ok { | ||||
| @@ -1610,14 +1674,18 @@ func (s *Server) ChatHandler(c *gin.Context) { | ||||
| 		} | ||||
|  | ||||
| 		resp.Message.Content = sb.String() | ||||
|  | ||||
| 		if len(req.Tools) > 0 { | ||||
| 			if toolCalls, ok := m.parseToolCalls(sb.String()); ok { | ||||
| 				resp.Message.ToolCalls = toolCalls | ||||
| 				resp.Message.Content = "" | ||||
| 			} | ||||
| 		if len(toolCalls) > 0 { | ||||
| 			resp.Message.ToolCalls = toolCalls | ||||
| 			// resp.Message.Content = "" | ||||
| 		} | ||||
|  | ||||
| 		// if len(req.Tools) > 0 { | ||||
| 		// 	if toolCalls, ok := m.ParseToolCalls(sb.String()); ok { | ||||
| 		// 		resp.Message.ToolCalls = toolCalls | ||||
| 		// 		resp.Message.Content = "" | ||||
| 		// 	} | ||||
| 		// } | ||||
|  | ||||
| 		c.JSON(http.StatusOK, resp) | ||||
| 		return | ||||
| 	} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user