Compare commits
2 Commits
main
...
parth/serv
Author | SHA1 | Date | |
---|---|---|---|
![]() |
3bc9d42e2e | ||
![]() |
4053c489b4 |
@ -20,6 +20,7 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
@ -62,6 +63,7 @@ type Model struct {
|
|||||||
Digest string
|
Digest string
|
||||||
Options map[string]any
|
Options map[string]any
|
||||||
Messages []api.Message
|
Messages []api.Message
|
||||||
|
ToolPrefix string
|
||||||
|
|
||||||
Template *template.Template
|
Template *template.Template
|
||||||
}
|
}
|
||||||
@ -260,7 +262,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model := &Model{
|
m := &Model{
|
||||||
Name: mp.GetFullTagname(),
|
Name: mp.GetFullTagname(),
|
||||||
ShortName: mp.GetShortTagname(),
|
ShortName: mp.GetShortTagname(),
|
||||||
Digest: digest,
|
Digest: digest,
|
||||||
@ -279,7 +281,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
}
|
}
|
||||||
defer configFile.Close()
|
defer configFile.Close()
|
||||||
|
|
||||||
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
|
if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -292,16 +294,16 @@ func GetModel(name string) (*Model, error) {
|
|||||||
|
|
||||||
switch layer.MediaType {
|
switch layer.MediaType {
|
||||||
case "application/vnd.ollama.image.model":
|
case "application/vnd.ollama.image.model":
|
||||||
model.ModelPath = filename
|
m.ModelPath = filename
|
||||||
model.ParentModel = layer.From
|
m.ParentModel = layer.From
|
||||||
case "application/vnd.ollama.image.embed":
|
case "application/vnd.ollama.image.embed":
|
||||||
// Deprecated in versions > 0.1.2
|
// Deprecated in versions > 0.1.2
|
||||||
// TODO: remove this warning in a future version
|
// TODO: remove this warning in a future version
|
||||||
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
|
||||||
case "application/vnd.ollama.image.adapter":
|
case "application/vnd.ollama.image.adapter":
|
||||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
m.AdapterPaths = append(m.AdapterPaths, filename)
|
||||||
case "application/vnd.ollama.image.projector":
|
case "application/vnd.ollama.image.projector":
|
||||||
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
m.ProjectorPaths = append(m.ProjectorPaths, filename)
|
||||||
case "application/vnd.ollama.image.prompt",
|
case "application/vnd.ollama.image.prompt",
|
||||||
"application/vnd.ollama.image.template":
|
"application/vnd.ollama.image.template":
|
||||||
bts, err := os.ReadFile(filename)
|
bts, err := os.ReadFile(filename)
|
||||||
@ -309,7 +311,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model.Template, err = template.Parse(string(bts))
|
m.Template, err = template.Parse(string(bts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -319,7 +321,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model.System = string(bts)
|
m.System = string(bts)
|
||||||
case "application/vnd.ollama.image.params":
|
case "application/vnd.ollama.image.params":
|
||||||
params, err := os.Open(filename)
|
params, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -328,7 +330,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
defer params.Close()
|
defer params.Close()
|
||||||
|
|
||||||
// parse model options parameters into a map so that we can see which fields have been specified explicitly
|
// parse model options parameters into a map so that we can see which fields have been specified explicitly
|
||||||
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
|
if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case "application/vnd.ollama.image.messages":
|
case "application/vnd.ollama.image.messages":
|
||||||
@ -338,7 +340,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
}
|
}
|
||||||
defer msgs.Close()
|
defer msgs.Close()
|
||||||
|
|
||||||
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
|
if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case "application/vnd.ollama.image.license":
|
case "application/vnd.ollama.image.license":
|
||||||
@ -346,11 +348,50 @@ func GetModel(name string) (*Model, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
model.License = append(model.License, string(bts))
|
m.License = append(m.License, string(bts))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return model, nil
|
capabilities := m.Capabilities()
|
||||||
|
if slices.Contains(capabilities, model.CapabilityTools) {
|
||||||
|
m.addToolPrefix()
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasToolPrefix checks if the completion starts with the tool prefix, ignoring whitespace
|
||||||
|
func (m *Model) HasToolPrefix(sb strings.Builder) bool {
|
||||||
|
text := strings.ReplaceAll(strings.TrimSpace(sb.String()), " ", "")
|
||||||
|
toolString := strings.ReplaceAll(strings.TrimSpace(m.ToolPrefix), " ", "")
|
||||||
|
|
||||||
|
if len(text) < len(toolString) {
|
||||||
|
return text == toolString[:len(text)]
|
||||||
|
}
|
||||||
|
return text[:len(toolString)] == toolString
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out what's between the start of the tools block, and the json response, and use it as a marker. Usually that's
|
||||||
|
// {- if .ToolCalls}this text{ range .ToolCalls}or maybe this text{{.name}}
|
||||||
|
func (m *Model) addToolPrefix() {
|
||||||
|
// create a subtree from the node that ranges over .ToolCalls
|
||||||
|
var previousNode parse.Node
|
||||||
|
toolCallsTemplate := m.Template.Subtree(func(node parse.Node) bool {
|
||||||
|
if rangeNode, ok := node.(*parse.RangeNode); ok {
|
||||||
|
return slices.Contains(template.Identifiers(rangeNode.Pipe), "ToolCalls")
|
||||||
|
}
|
||||||
|
previousNode = node
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if textNode, ok := previousNode.(*parse.TextNode); ok {
|
||||||
|
m.ToolPrefix = strings.TrimSpace(textNode.String())
|
||||||
|
}
|
||||||
|
if len(m.ToolPrefix) == 0 && len(toolCallsTemplate.Root.Nodes) > 0 {
|
||||||
|
rangeNode, ok := toolCallsTemplate.Root.Nodes[0].(*parse.RangeNode)
|
||||||
|
if ok && len(rangeNode.List.Nodes) > 0 {
|
||||||
|
m.ToolPrefix = rangeNode.List.Nodes[0].String()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyModel(src, dst model.Name) error {
|
func CopyModel(src, dst model.Name) error {
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@ -28,19 +29,20 @@ func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
|||||||
func TestExecuteWithTools(t *testing.T) {
|
func TestExecuteWithTools(t *testing.T) {
|
||||||
p := filepath.Join("testdata", "tools")
|
p := filepath.Join("testdata", "tools")
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
model string
|
model string
|
||||||
output string
|
output string
|
||||||
ok bool
|
ok bool
|
||||||
|
wellFormed bool
|
||||||
}{
|
}{
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true},
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
|
|
||||||
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
|
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true, false},
|
||||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false, false},
|
||||||
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||||
|
|
||||||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, false},
|
||||||
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
|
||||||
{"command-r-plus", "Action: ```json" + `
|
{"command-r-plus", "Action: ```json" + `
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
@ -58,16 +60,17 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
` + "```", true},
|
` + "```", true, true},
|
||||||
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
|
||||||
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true},
|
||||||
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
|
||||||
{"llama3-groq-tool-use", `<tool_call>
|
{"llama3-groq-tool-use", `<tool_call>
|
||||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
|
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
|
||||||
</tool_call>`, true},
|
</tool_call>`, true, true},
|
||||||
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
|
{"xlam", `### Response:
|
||||||
{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
|
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true, true},
|
||||||
|
{"nemotron", `<toolcall> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true, true},
|
||||||
}
|
}
|
||||||
|
|
||||||
var tools []api.Tool
|
var tools []api.Tool
|
||||||
@ -119,6 +122,21 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("prefix", func(t *testing.T) {
|
||||||
|
m := &Model{Template: tmpl}
|
||||||
|
m.addToolPrefix()
|
||||||
|
|
||||||
|
if tt.wellFormed {
|
||||||
|
if len(m.ToolPrefix) == 0 {
|
||||||
|
t.Fatalf("No tool prefix detected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(strings.TrimSpace(tt.output), m.ToolPrefix) {
|
||||||
|
t.Fatalf("incorrect tool prefix: \"%s\", \"%s\"", m.ToolPrefix, tt.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("parse", func(t *testing.T) {
|
t.Run("parse", func(t *testing.T) {
|
||||||
m := &Model{Template: tmpl}
|
m := &Model{Template: tmpl}
|
||||||
actual, ok := m.parseToolCalls(tt.output)
|
actual, ok := m.parseToolCalls(tt.output)
|
||||||
@ -177,3 +195,64 @@ func TestParseObjects(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddToolPrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
template string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "prefix_from_previous_text_node",
|
||||||
|
template: `Previous text node{{- range .ToolCalls}}{{.name}}{{end}}`,
|
||||||
|
want: "Previous text node",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "prefix_from_range_node",
|
||||||
|
template: `{{- range .ToolCalls}}[TOOL_CALLS]{{.name}}{{end}}`,
|
||||||
|
want: "[TOOL_CALLS]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "prefix_with_extra_whitespace",
|
||||||
|
template: ` Previous text with spaces {{- range .ToolCalls}}{{.name}}{{end}}`,
|
||||||
|
want: "Previous text with spaces",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "prefix_with_newlines",
|
||||||
|
template: "First line\nSecond line\n{{- range .ToolCalls}}{{.name}}{{end}}",
|
||||||
|
want: "First line\nSecond line",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool_calls_json_template",
|
||||||
|
template: `{{ if .Content }}{{ .Content }}{{- else if .ToolCalls }}<tool_call>
|
||||||
|
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}</tool_call>
|
||||||
|
{{ end }}`,
|
||||||
|
want: `<tool_call>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral_tool_calls_template",
|
||||||
|
template: `{{- if .Content }} {{ .Content }}
|
||||||
|
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
||||||
|
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
|
{{- end }}]
|
||||||
|
{{- end }}</s>`,
|
||||||
|
want: "[TOOL_CALLS] [",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tmpl, err := template.Parse(tt.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &Model{Template: tmpl}
|
||||||
|
m.addToolPrefix()
|
||||||
|
|
||||||
|
if m.ToolPrefix != tt.want {
|
||||||
|
t.Errorf("incorrect tool prefix:\ngot: %q\nwant: %q", m.ToolPrefix, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1526,6 +1526,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
defer close(ch)
|
defer close(ch)
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
var toolCallIndex int = 0
|
var toolCallIndex int = 0
|
||||||
|
var mightBeTools bool = true
|
||||||
|
buf := make([]api.ChatResponse, 0)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
@ -1551,18 +1553,29 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: tool call checking and filtering should be moved outside of this callback once streaming
|
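// Pass the response straight through when streaming is disabled, no tools were requested, or the output has already been ruled out as a tool call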
|
||||||
// however this was a simple change for now without reworking streaming logic of this (and other)
|
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 || !mightBeTools {
|
||||||
// handlers
|
|
||||||
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
|
|
||||||
ch <- res
|
ch <- res
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sb.WriteString(r.Content)
|
||||||
|
|
||||||
|
// Buffer up responses while we're unsure whether to stream.
|
||||||
|
buf = append(buf, res)
|
||||||
|
|
||||||
|
// not a tools response, continue streaming.
|
||||||
|
if !m.HasToolPrefix(sb) {
|
||||||
|
mightBeTools = false
|
||||||
|
for _, item := range buf {
|
||||||
|
ch <- item
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Streaming tool calls:
|
// Streaming tool calls:
|
||||||
// If tools are recognized, use a flag to track the sending of a tool downstream
|
// If tools are recognized, use a flag to track the sending of a tool downstream
|
||||||
// This ensures that content is cleared from the message on the last chunk sent
|
// This ensures that content is cleared from the message on the last chunk sent
|
||||||
sb.WriteString(r.Content)
|
|
||||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
for i := range toolCalls {
|
for i := range toolCalls {
|
||||||
@ -1573,8 +1586,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
sb.Reset()
|
sb.Reset()
|
||||||
ch <- res
|
ch <- res
|
||||||
return
|
return
|
||||||
|
} else {
|
||||||
|
if !strings.HasPrefix(sb.String(), "{") {
|
||||||
|
ch <- res
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Done {
|
if r.Done {
|
||||||
// Send any remaining content if no tool calls were detected
|
// Send any remaining content if no tool calls were detected
|
||||||
if toolCallIndex == 0 {
|
if toolCallIndex == 0 {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user