draft
This commit is contained in:
parent
1a83581a8e
commit
d2b25c1bfb
@ -15,6 +15,7 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template/parse"
|
"text/template/parse"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/convert"
|
"github.com/ollama/ollama/convert"
|
||||||
@ -312,6 +313,7 @@ func detectContentType(r io.Reader) (string, error) {
|
|||||||
// mxyng: this only really works if the input contains tool calls in some JSON format
|
// mxyng: this only really works if the input contains tool calls in some JSON format
|
||||||
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
// create a subtree from the node that ranges over .ToolCalls
|
// create a subtree from the node that ranges over .ToolCalls
|
||||||
|
start := time.Now()
|
||||||
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
||||||
if t, ok := n.(*parse.RangeNode); ok {
|
if t, ok := n.(*parse.RangeNode); ok {
|
||||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||||
@ -415,5 +417,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
end := time.Now()
|
||||||
|
slog.Debug("parseToolCalls", "duration", end.Sub(start).String())
|
||||||
return toolCalls, len(toolCalls) > 0
|
return toolCalls, len(toolCalls) > 0
|
||||||
}
|
}
|
||||||
|
@ -1369,7 +1369,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if req.Stream != nil && !*req.Stream {
|
if (req.Stream != nil && !*req.Stream) || ((req.Stream == nil || *req.Stream) && len(req.Tools) > 0) {
|
||||||
var resp api.ChatResponse
|
var resp api.ChatResponse
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for rr := range ch {
|
for rr := range ch {
|
||||||
@ -1400,6 +1400,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (req.Stream == nil || *req.Stream) && len(resp.Message.ToolCalls) > 0 {
|
||||||
|
toolCh := make(chan any)
|
||||||
|
go func() {
|
||||||
|
toolCalls := resp.Message.ToolCalls
|
||||||
|
for _, toolCall := range toolCalls[:len(toolCalls)-1] {
|
||||||
|
chunk := api.ChatResponse{
|
||||||
|
Model: resp.Model,
|
||||||
|
CreatedAt: resp.CreatedAt,
|
||||||
|
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{toolCall}},
|
||||||
|
DoneReason: "tool_calls",
|
||||||
|
}
|
||||||
|
toolCh <- chunk
|
||||||
|
}
|
||||||
|
resp.Message.ToolCalls = []api.ToolCall{toolCalls[len(toolCalls)-1]}
|
||||||
|
toolCh <- resp
|
||||||
|
close(toolCh)
|
||||||
|
}()
|
||||||
|
streamResponse(c, toolCh)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user