From d2b25c1bfb4fe9a2877a2e3fde24c0e46545a6ee Mon Sep 17 00:00:00 2001 From: Roy Han Date: Mon, 29 Jul 2024 16:59:02 -0700 Subject: [PATCH] draft --- server/model.go | 4 ++++ server/routes.go | 23 ++++++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/server/model.go b/server/model.go index c6d3078f1..b734a298b 100644 --- a/server/model.go +++ b/server/model.go @@ -15,6 +15,7 @@ import ( "slices" "strings" "text/template/parse" + "time" "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" @@ -312,6 +313,7 @@ func detectContentType(r io.Reader) (string, error) { // mxyng: this only really works if the input contains tool calls in some JSON format func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { // create a subtree from the node that ranges over .ToolCalls + start := time.Now() tmpl := m.Template.Subtree(func(n parse.Node) bool { if t, ok := n.(*parse.RangeNode); ok { return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") @@ -415,5 +417,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { } } + end := time.Now() + slog.Debug("parseToolCalls", "duration", end.Sub(start).String()) return toolCalls, len(toolCalls) > 0 } diff --git a/server/routes.go b/server/routes.go index e6ffe5268..63eac5634 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1369,7 +1369,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } }() - if req.Stream != nil && !*req.Stream { + if (req.Stream != nil && !*req.Stream) || ((req.Stream == nil || *req.Stream) && len(req.Tools) > 0) { var resp api.ChatResponse var sb strings.Builder for rr := range ch { @@ -1400,6 +1400,27 @@ func (s *Server) ChatHandler(c *gin.Context) { } } + if (req.Stream == nil || *req.Stream) && len(resp.Message.ToolCalls) > 0 { + toolCh := make(chan any) + go func() { + toolCalls := resp.Message.ToolCalls + for _, toolCall := range toolCalls[:len(toolCalls)-1] { + chunk := api.ChatResponse{ + Model: resp.Model, + CreatedAt: resp.CreatedAt, + Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{toolCall}}, + DoneReason: "tool_calls", + } + toolCh <- chunk + } + resp.Message.ToolCalls = []api.ToolCall{toolCalls[len(toolCalls)-1]} + toolCh <- resp + close(toolCh) + }() + streamResponse(c, toolCh) + return + } + c.JSON(http.StatusOK, resp) return }