Compare commits

2 Commits

Author       SHA1        Message                                       Date
ParthSareen  3bc9d42e2e  rebase + fix tests                            2025-04-03 17:31:21 -07:00
ParthSareen  4053c489b4  server: enable content streaming with tools   2025-04-03 17:09:59 -07:00

25 changed files with 424 additions and 502 deletions

View File

@@ -291,7 +291,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
 - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
 - [Ollamac](https://github.com/kevinhermawan/Ollamac)
-- [big-AGI](https://github.com/enricoros/big-AGI)
+- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
 - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
 - [Amica](https://github.com/semperai/amica)
 - [chatd](https://github.com/BruceMacD/chatd)
@@ -348,7 +348,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
 - [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
 - [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
-- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
+- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
 - [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
 - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
 - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -440,7 +440,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
 - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
 - [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
-- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)

 ### Apple Vision Pro

View File

@@ -166,48 +166,6 @@ type Tool struct {
 	Function ToolFunction `json:"function"`
 }

-// PropertyType can be either a string or an array of strings
-type PropertyType []string
-
-// UnmarshalJSON implements the json.Unmarshaler interface
-func (pt *PropertyType) UnmarshalJSON(data []byte) error {
-	// Try to unmarshal as a string first
-	var s string
-	if err := json.Unmarshal(data, &s); err == nil {
-		*pt = []string{s}
-		return nil
-	}
-
-	// If that fails, try to unmarshal as an array of strings
-	var a []string
-	if err := json.Unmarshal(data, &a); err != nil {
-		return err
-	}
-	*pt = a
-
-	return nil
-}
-
-// MarshalJSON implements the json.Marshaler interface
-func (pt PropertyType) MarshalJSON() ([]byte, error) {
-	if len(pt) == 1 {
-		// If there's only one type, marshal as a string
-		return json.Marshal(pt[0])
-	}
-
-	// Otherwise marshal as an array
-	return json.Marshal([]string(pt))
-}
-
-// String returns a string representation of the PropertyType
-func (pt PropertyType) String() string {
-	if len(pt) == 0 {
-		return ""
-	}
-
-	if len(pt) == 1 {
-		return pt[0]
-	}
-
-	return fmt.Sprintf("%v", []string(pt))
-}
-
 type ToolFunction struct {
 	Name        string `json:"name"`
 	Description string `json:"description"`
@@ -215,9 +173,9 @@ type ToolFunction struct {
 		Type       string   `json:"type"`
 		Required   []string `json:"required"`
 		Properties map[string]struct {
-			Type        PropertyType `json:"type"`
+			Type        string   `json:"type"`
 			Description string   `json:"description"`
-			Enum        []any    `json:"enum,omitempty"`
+			Enum        []string `json:"enum,omitempty"`
 		} `json:"properties"`
 	} `json:"parameters"`
 }
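For context, the removed `PropertyType` let a tool parameter's JSON Schema `type` round-trip as either a bare string or an array of strings; reverting to plain `string` drops inputs like `"type": ["string", "null"]`. A small self-contained sketch of the behavior this diff removes (the type and method are copied from the deleted code; the `main` driver is illustrative only):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// PropertyType is the type deleted in this diff: a []string that
// unmarshals from either a JSON string or a JSON array of strings.
type PropertyType []string

func (pt *PropertyType) UnmarshalJSON(data []byte) error {
	// Try a bare string first: "type": "string"
	var s string
	if err := json.Unmarshal(data, &s); err == nil {
		*pt = []string{s}
		return nil
	}
	// Fall back to an array: "type": ["string", "null"]
	var a []string
	if err := json.Unmarshal(data, &a); err != nil {
		return err
	}
	*pt = a
	return nil
}

func main() {
	var single, union PropertyType
	_ = json.Unmarshal([]byte(`"string"`), &single)
	_ = json.Unmarshal([]byte(`["string", "null"]`), &union)
	fmt.Println(single, union) // [string] [string null]
}
```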

View File

@@ -231,144 +231,3 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
 		}
 	}
 }
-
-func TestToolFunction_UnmarshalJSON(t *testing.T) {
-	tests := []struct {
-		name    string
-		input   string
-		wantErr string
-	}{
-		{
-			name: "valid enum with same types",
-			input: `{
-				"name": "test",
-				"description": "test function",
-				"parameters": {
-					"type": "object",
-					"required": ["test"],
-					"properties": {
-						"test": {
-							"type": "string",
-							"description": "test prop",
-							"enum": ["a", "b", "c"]
-						}
-					}
-				}
-			}`,
-			wantErr: "",
-		},
-		{
-			name: "empty enum array",
-			input: `{
-				"name": "test",
-				"description": "test function",
-				"parameters": {
-					"type": "object",
-					"required": ["test"],
-					"properties": {
-						"test": {
-							"type": "string",
-							"description": "test prop",
-							"enum": []
-						}
-					}
-				}
-			}`,
-			wantErr: "",
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			var tf ToolFunction
-			err := json.Unmarshal([]byte(tt.input), &tf)
-
-			if tt.wantErr != "" {
-				require.Error(t, err)
-				assert.Contains(t, err.Error(), tt.wantErr)
-			} else {
-				require.NoError(t, err)
-			}
-		})
-	}
-}
-
-func TestPropertyType_UnmarshalJSON(t *testing.T) {
-	tests := []struct {
-		name     string
-		input    string
-		expected PropertyType
-	}{
-		{
-			name:     "string type",
-			input:    `"string"`,
-			expected: PropertyType{"string"},
-		},
-		{
-			name:     "array of types",
-			input:    `["string", "number"]`,
-			expected: PropertyType{"string", "number"},
-		},
-		{
-			name:     "array with single type",
-			input:    `["string"]`,
-			expected: PropertyType{"string"},
-		},
-	}
-
-	for _, test := range tests {
-		t.Run(test.name, func(t *testing.T) {
-			var pt PropertyType
-			if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
-				t.Errorf("Unexpected error: %v", err)
-			}
-
-			if len(pt) != len(test.expected) {
-				t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
-			}
-
-			for i, v := range pt {
-				if v != test.expected[i] {
-					t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
-				}
-			}
-		})
-	}
-}
-
-func TestPropertyType_MarshalJSON(t *testing.T) {
-	tests := []struct {
-		name     string
-		input    PropertyType
-		expected string
-	}{
-		{
-			name:     "single type",
-			input:    PropertyType{"string"},
-			expected: `"string"`,
-		},
-		{
-			name:     "multiple types",
-			input:    PropertyType{"string", "number"},
-			expected: `["string","number"]`,
-		},
-		{
-			name:     "empty type",
-			input:    PropertyType{},
-			expected: `[]`,
-		},
-	}
-
-	for _, test := range tests {
-		t.Run(test.name, func(t *testing.T) {
-			data, err := json.Marshal(test.input)
-			if err != nil {
-				t.Errorf("Unexpected error: %v", err)
-			}
-
-			if string(data) != test.expected {
-				t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
-			}
-		})
-	}
-}

View File

@@ -1381,6 +1381,7 @@ func NewCLI() *cobra.Command {
 				envVars["OLLAMA_NOPRUNE"],
 				envVars["OLLAMA_ORIGINS"],
 				envVars["OLLAMA_SCHED_SPREAD"],
+				envVars["OLLAMA_TMPDIR"],
 				envVars["OLLAMA_FLASH_ATTENTION"],
 				envVars["OLLAMA_KV_CACHE_TYPE"],
 				envVars["OLLAMA_LLM_LIBRARY"],

View File

@@ -26,6 +26,7 @@ When you run Ollama on **Windows**, there are a few different locations. You can
 - `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
 - `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
 - `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
+- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories

 To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
@@ -68,6 +69,10 @@ If you run into problems on Linux and want to install an older version, or you'd
 curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
 ```

+## Linux tmp noexec
+
+If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
+
 ## Linux docker

 If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.

View File

@@ -62,6 +62,7 @@ the explorer window by hitting `<Ctrl>+R` and type in:
 - *upgrade.log* contains log output for upgrades
 - `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
 - `explorer %HOMEPATH%\.ollama` contains models and configuration
+- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories

 ## Uninstall

View File

@@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
 	embedCtx := ctx

 	var genwg sync.WaitGroup
+	genwg.Add(1)
 	go func() {
-		genwg.Add(1)
 		defer genwg.Done()
 		slog.Info("Starting generate request")
 		DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
@@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
 	counterMu := sync.Mutex{}
 	var embedwg sync.WaitGroup
 	for i := 0; i < threadCount; i++ {
+		embedwg.Add(1)
 		go func(i int) {
-			embedwg.Add(1)
 			defer embedwg.Done()
 			slog.Info("embed started", "id", i)
 			embedReq := api.EmbeddingRequest{
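This reordering fixes a classic `sync.WaitGroup` race: `Add` must be called before the goroutine starts, otherwise `Wait` can observe a zero counter and return before the goroutine has registered itself. A minimal standalone illustration (not taken from the PR):

```go
package main

import "sync"

func main() {
	var wg sync.WaitGroup

	// Racy (the old code): if main reaches wg.Wait() before the
	// scheduler runs the goroutine, the counter is still zero and
	// Wait returns immediately.
	//
	//	go func() {
	//		wg.Add(1) // too late
	//		defer wg.Done()
	//	}()

	// Correct (the new code): register before starting the goroutine.
	wg.Add(1)
	go func() {
		defer wg.Done()
		// ... do work ...
	}()

	wg.Wait()
}
```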

View File

@@ -56,9 +56,8 @@ type Cache interface {
 	// StartForward is called before the start of the model's forward pass.
 	// For each token in the coming batch, there must be a corresponding
-	// entry in positions and seqs. reserve is to preallocate memory
-	// without actually storing data in the cache.
-	StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
+	// entry in positions and seqs.
+	StartForward(ctx ml.Context, batch input.Batch) error

 	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
 	CopyPrefix(srcSeq, dstSeq int, len int32)

View File

@@ -146,13 +146,12 @@ func (c *Causal) Close() {
 	}
 }

-func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
+func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
 	c.curBatchSize = len(batch.Positions)
 	c.curSequences = batch.Sequences
 	c.curPositions = batch.Positions
 	c.opts.Except = nil

-	if !reserve {
 	c.updateSlidingWindow()

 	var err error
@@ -191,15 +190,7 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
 		}

 		c.cellRanges[seq] = seqRange
 	}
-	} else {
-		// If we are reserving memory, don't update any of the cache metadata but set the size
-		// to the worst case.
-		c.curLoc = 0
-		c.curCellRange.min = 0
-		c.curCellRange.max = len(c.cells) - 1
-	}
-
-	var err error
 	c.curMask, err = c.buildMask(ctx)

 	return err

View File

@@ -5,6 +5,7 @@ import (
 	"slices"
 	"testing"

+	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model/input"
 )
@@ -280,7 +281,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 			context := backend.NewContext()
 			defer context.Close()

-			err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
+			err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
 			if err != nil {
 				panic(err)
 			}
@@ -314,7 +315,7 @@ func TestCanResume(t *testing.T) {
 	err := cache.StartForward(context, input.Batch{
 		Positions: []int32{0, 1, 2, 3},
 		Sequences: []int{0, 0, 0, 0},
-	}, false)
+	})
 	if err != nil {
 		t.Fatalf("StartForward failed: %v", err)
 	}
@@ -341,7 +342,7 @@ func TestCanResume(t *testing.T) {
 	err = cache.StartForward(context, input.Batch{
 		Positions: []int32{4, 5},
 		Sequences: []int{0, 0},
-	}, false)
+	})
 	if err != nil {
 		t.Fatalf("StartForward failed: %v", err)
 	}
@@ -371,8 +372,14 @@ func TestCanResume(t *testing.T) {
 	}
 }

-type testBackend struct {
-	ml.Backend
+type testBackend struct{}
+
+func (b *testBackend) Config() fs.Config {
+	panic("not implemented")
+}
+
+func (b *testBackend) Get(name string) ml.Tensor {
+	panic("not implemented")
 }

 func (b *testBackend) NewContext() ml.Context {
@@ -383,10 +390,12 @@ func (b *testBackend) NewContextSize(int) ml.Context {
 	return &testContext{}
 }

-type testContext struct {
-	ml.Context
+func (b *testBackend) SystemInfo() string {
+	return "not implemented"
 }

+type testContext struct{}
+
 func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
 	total := 0
@@ -431,8 +440,6 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

 func (c *testContext) Compute(...ml.Tensor) {}

-func (c *testContext) Reserve() error { return nil }
-
 func (c *testContext) MaxGraphNodes() int {
 	return 10
 }
@@ -440,8 +447,6 @@ func (c *testContext) MaxGraphNodes() int {
 func (c *testContext) Close() {}

 type testTensor struct {
-	ml.Tensor
-
 	dtype       ml.DType
 	elementSize int
 	data        []float32
@@ -469,6 +474,10 @@ func (t *testTensor) DType() ml.DType {
 	return t.dtype
 }

+func (t *testTensor) Bytes() []byte {
+	panic("not implemented")
+}
+
 func (t *testTensor) Floats() []float32 {
 	out := make([]float32, len(t.data))
 	copy(out, t.data)
@@ -493,6 +502,64 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return out
 }

+func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Cos(ctx ml.Context) ml.Tensor  { panic("not implemented") }
+func (t *testTensor) Sin(ctx ml.Context) ml.Tensor  { panic("not implemented") }
+func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
+func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
+func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
+
+func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
+	panic("not implemented")
+}
+
 func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	offset /= t.elementSize
@@ -515,7 +582,43 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	return view
 }

+func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { panic("not implemented") }
+
+func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+	panic("not implemented")
+}
+
 func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	copy(t2.(*testTensor).data, t.data)
 	return nil
 }
+
+func (t *testTensor) Duplicate(ctx ml.Context) ml.Tensor { panic("not implemented") }

View File

@@ -27,11 +27,6 @@ type EncoderCache struct {
 	// anything will be stored)
 	curPos int32

-	// curReserve indicates that this forward pass is only for
-	// memory reservation and we should not update our metadata
-	// based on it.
-	curReserve bool
-
 	// ** cache metadata **

 	// was something stored in the cache?
@@ -88,14 +83,12 @@ func (c *EncoderCache) Close() {
 	}
 }

-func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
+func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
 	// We work with the most recent image
 	if len(batch.Multimodal) > 0 {
 		c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
 	}

-	c.curReserve = reserve
-
 	return nil
 }
@@ -112,10 +105,8 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
 }

 func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
-	if !c.curReserve {
 	c.encoderPos = c.curPos
 	c.encoderCached = true
-	}

 	if c.config.PermutedV {
 		value = value.Permute(ctx, 1, 2, 0, 3)

View File

@@ -41,9 +41,9 @@ func (c *WrapperCache) Close() {
 	}
 }

-func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
 	for i, cache := range c.caches {
-		err := cache.StartForward(ctx, batch, reserve)
+		err := cache.StartForward(ctx, batch)
 		if err != nil {
 			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
 			for j := i - 1; j >= 0; j-- {

View File

@@ -97,13 +97,6 @@ type Context interface {
 	Forward(...Tensor) Context
 	Compute(...Tensor)

-	// Reserve is analogous to Compute but rather than executing a
-	// graph, simply preallocates memory. Typically called with a
-	// worst case graph to ensure all resources are available for
-	// for future inference.
-	Reserve() error
-
 	MaxGraphNodes() int
 	Close()

View File

@@ -10,7 +10,6 @@ import "C"

 import (
 	"context"
-	"errors"
 	"fmt"
 	"io"
 	"log/slog"
@@ -44,11 +43,7 @@ func devices() []*C.struct_ggml_backend_device {

 type Backend struct {
 	meta  *fsggml.GGML
 	sched *C.struct_ggml_backend_sched
-
-	schedBackends []*C.struct_ggml_backend
-	schedBufts    []*C.struct_ggml_backend_buffer_type
-
 	tensors map[string]*C.struct_ggml_tensor

 	// input is the backend used for inputs
@@ -286,10 +281,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
 		}

 		b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
-		if b == nil {
-			return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
-		}
-
 		C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
 		bbs[c] = b
 	}
@@ -328,14 +319,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
 			tts[i] = tt
 		}

-		// Create a new FD for each goroutine so that each FD is read sequentially, rather than
-		// seeking around within an FD shared between all goroutines.
-		file, err := os.Open(r.Name())
-		if err != nil {
-			return err
-		}
-		defer file.Close()
-
-		sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
+		sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
 		bts := make([]byte, 128*format.KibiByte)

 		var s uint64
@@ -394,6 +378,8 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
 		schedBackends = append(schedBackends, b)
 		schedBufts = append(schedBufts, bt)

+		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
+
 		if C.ggml_backend_is_cpu(b) {
 			// set number of threads for cpu backend
 			C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
@@ -412,8 +398,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
 			C.size_t(maxGraphNodes),
 			C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
 		),
-		schedBackends: schedBackends,
-		schedBufts:    schedBufts,
 		input: deviceBufferTypes[input.d],
 		layers: func() map[int]*C.struct_ggml_backend_buffer_type {
 			m := make(map[int]*C.struct_ggml_backend_buffer_type)
@@ -539,24 +523,6 @@ func (c Context) Compute(tensors ...ml.Tensor) {
 	}
 }

-func (c Context) Reserve() error {
-	if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
-		C.ggml_backend_sched_reset(c.b.sched)
-		return errors.New("failed to reserve graph")
-	}
-
-	slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
-	for i := range c.b.schedBackends {
-		size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
-		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
-			"size", format.HumanBytes2(uint64(size)))
-	}
-
-	C.ggml_backend_sched_reset(c.b.sched)
-
-	return nil
-}
-
 func (c Context) MaxGraphNodes() int {
 	return c.maxGraphNodes
 }
@@ -574,9 +540,9 @@ func pad(length, pad C.size_t) C.size_t {
 	return ((length + pad - 1) / pad) * pad
 }

-func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
+func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
 	if c.buft == nil {
-		panic("set Input or Layer before creating tensors")
+		panic("set Input, Output, or Layer before creating tensors")
 	}

 	var cdtype uint32
@@ -597,7 +563,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
 	if len(shape) < 1 || shape[0] == 0 {
 		var shape C.int64_t = 0
-		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
+		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
 	} else if len(shape) > 4 {
 		panic("unsupported number of dimensions")
 	}
@@ -611,29 +577,16 @@ func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
 	t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
 	size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
 	b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
-	if b == nil {
-		return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
-	}
-
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
-	return &Tensor{b: c.b, t: t}, nil
+	return &Tensor{b: c.b, t: t}
 }

 func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
-	t, err := c.newTensor(dtype, shape)
-	if err != nil {
-		panic(err)
-	}
-
-	return t
+	return c.newTensor(dtype, shape)
 }

 func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
-	t, err := c.newTensor(dtype, shape)
-	if err != nil {
-		panic(err)
-	}
-
+	t := c.newTensor(dtype, shape)
 	C.ggml_set_zero(t.(*Tensor).t)
 	return t
 }
@@ -661,11 +614,7 @@ func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
 		return nil, err
 	}

-	t, err := c.newTensor(ml.DTypeF32, shape)
-	if err != nil {
-		return nil, err
-	}
-
+	t := c.newTensor(ml.DTypeF32, shape)
 	if len(s) > 0 {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
 	}
@@ -678,11 +627,7 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 		return nil, err
 	}

-	t, err := c.newTensor(ml.DTypeI32, shape)
-	if err != nil {
-		return nil, err
-	}
-
+	t := c.newTensor(ml.DTypeI32, shape)
 	if len(s) > 0 {
 		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
 	}

View File

@@ -299,7 +299,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
 	cache := m.Config().Cache
 	if cache != nil {
-		err := cache.StartForward(ctx, batch, false)
+		err := cache.StartForward(ctx, batch)
 		if err != nil {
 			return nil, err
 		}

View File

@@ -283,25 +283,25 @@ func TestChatMiddleware(t *testing.T) {
 					Type       string   `json:"type"`
 					Required   []string `json:"required"`
 					Properties map[string]struct {
-						Type        api.PropertyType `json:"type"`
+						Type        string   `json:"type"`
 						Description string   `json:"description"`
-						Enum        []any    `json:"enum,omitempty"`
+						Enum        []string `json:"enum,omitempty"`
 					} `json:"properties"`
 				}{
 					Type:     "object",
 					Required: []string{"location"},
 					Properties: map[string]struct {
-						Type        api.PropertyType `json:"type"`
+						Type        string   `json:"type"`
 						Description string   `json:"description"`
-						Enum        []any    `json:"enum,omitempty"`
+						Enum        []string `json:"enum,omitempty"`
 					}{
 						"location": {
-							Type:        api.PropertyType{"string"},
+							Type:        "string",
 							Description: "The city and state",
 						},
 						"unit": {
-							Type: api.PropertyType{"string"},
-							Enum: []any{"celsius", "fahrenheit"},
+							Type: "string",
+							Enum: []string{"celsius", "fahrenheit"},
 						},
 					},
 				},

View File

@@ -11,13 +11,10 @@ import (
 	"os"
 	"os/user"
 	"path/filepath"
-	"runtime"
 	"slices"
 	"strconv"
 	"strings"
-	"sync"

-	"golang.org/x/sync/errgroup"
 	"golang.org/x/text/encoding/unicode"
 	"golang.org/x/text/transform"
@@ -147,26 +144,13 @@ func fileDigestMap(path string) (map[string]string, error) {
 		files = []string{path}
 	}

-	var mu sync.Mutex
-	var g errgroup.Group
-	g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
 	for _, f := range files {
-		g.Go(func() error {
-			digest, err := digestForFile(f)
-			if err != nil {
-				return err
-			}
-			mu.Lock()
-			defer mu.Unlock()
-			fl[f] = digest
-			return nil
-		})
-	}
-	if err := g.Wait(); err != nil {
-		return nil, err
+		digest, err := digestForFile(f)
+		if err != nil {
+			return nil, err
+		}
+		fl[f] = digest
 	}

 	return fl, nil
 }

View File

@@ -448,7 +448,7 @@ func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
 func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
 func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
 func (m *mockCache) Close() {}
-func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { return nil }
+func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
 func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
 func (m *mockCache) SetConfig(ml.CacheConfig) {}
 func (m *mockCache) CanResume(seq int, pos int32) bool { return true }

View File

@@ -728,51 +728,6 @@ func (m *multiLPath) String() string {
 	return strings.Join(*m, ", ")
 }

-func (s *Server) reserveWorstCaseGraph() error {
-	ctx := s.model.Backend().NewContext()
-	defer ctx.Close()
-
-	var batch input.Batch
-
-	inputs := make([]int32, s.batchSize)
-	batch.Positions = make([]int32, len(inputs))
-	batch.Sequences = make([]int, len(inputs))
-	for i := range inputs {
-		batch.Positions[i] = int32(i)
-	}
-
-	batch.Outputs = make([]int32, s.parallel)
-	for i := range batch.Outputs {
-		batch.Outputs[i] = int32(i)
-	}
-
-	var err error
-	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
-	if err != nil {
-		return err
-	}
-
-	cache := s.model.Config().Cache
-	if cache != nil {
-		err := cache.StartForward(ctx, batch, true)
-		if err != nil {
-			return err
-		}
-	}
-
-	t, err := s.model.Forward(ctx, batch)
-	if err != nil {
-		return err
-	}
-
-	err = ctx.Forward(t).Reserve()
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
 func (s *Server) loadModel(
 	ctx context.Context,
 	mpath string,
@@ -810,11 +765,6 @@ func (s *Server) loadModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

-	err = s.reserveWorstCaseGraph()
-	if err != nil {
-		panic(err)
-	}
-
 	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }

View File

@@ -497,8 +497,12 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
 		return nil, err
 	}

+	var offset int64
+	for offset < stat.Size() {
 		f, n, err := ggml.Decode(blob, 0)
-		if err != nil {
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
 			return nil, err
 		}
@@ -510,7 +514,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
 		}

 		var layer Layer
-		if digest != "" && n == stat.Size() {
+		if digest != "" && n == stat.Size() && offset == 0 {
 			layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
 			if err != nil {
 				slog.Debug("could not create new layer from layer", "error", err)
@@ -520,13 +524,15 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
 		// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
 		if layer.Digest == "" {
-			layer, err = NewLayer(io.NewSectionReader(blob, 0, n), mediatype)
+			layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype)
 			if err != nil {
 				return nil, err
 			}
 		}

 		layers = append(layers, &layerGGML{layer, f})
+		offset = n
+	}

 	return detectChatTemplate(layers)
 }
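The loop above changes `ggufLayers` from decoding a single GGUF blob to walking every blob in the file: decode at the current offset, treat `io.EOF` as a clean stop, and advance past what was consumed. A rough sketch of the same shape against a hypothetical length-prefixed record format (not the real GGUF decoder):

```go
package main

import (
	"encoding/binary"
	"errors"
	"io"
	"os"
)

// readRecords mirrors the decode-in-a-loop pattern: read records
// until the offset reaches the file size, treating io.EOF as a
// normal end of input rather than an error.
func readRecords(f *os.File) ([][]byte, error) {
	stat, err := f.Stat()
	if err != nil {
		return nil, err
	}

	var records [][]byte
	var offset int64
	for offset < stat.Size() {
		sr := io.NewSectionReader(f, offset, stat.Size()-offset)

		var size uint32
		if err := binary.Read(sr, binary.LittleEndian, &size); err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return nil, err
		}

		buf := make([]byte, size)
		if _, err := io.ReadFull(sr, buf); err != nil {
			return nil, err
		}

		records = append(records, buf)
		offset += 4 + int64(size) // length prefix plus payload
	}

	return records, nil
}

func main() {
	f, err := os.Open("records.bin") // hypothetical input file
	if err != nil {
		panic(err)
	}
	defer f.Close()

	records, err := readRecords(f)
	if err != nil {
		panic(err)
	}
	println(len(records), "records")
}
```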

View File

@@ -20,6 +20,7 @@ import (
 	"slices"
 	"strconv"
 	"strings"
+	"text/template/parse"

 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
@@ -62,6 +63,7 @@ type Model struct {
 	Digest     string
 	Options    map[string]any
 	Messages   []api.Message
+	ToolPrefix string

 	Template *template.Template
 }
@@ -260,7 +262,7 @@ func GetModel(name string) (*Model, error) {
 		return nil, err
 	}

-	model := &Model{
+	m := &Model{
 		Name:      mp.GetFullTagname(),
 		ShortName: mp.GetShortTagname(),
 		Digest:    digest,
@@ -279,7 +281,7 @@ func GetModel(name string) (*Model, error) {
 		}
 		defer configFile.Close()

-		if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
+		if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
 			return nil, err
 		}
 	}
@@ -292,16 +294,16 @@ func GetModel(name string) (*Model, error) {

 		switch layer.MediaType {
 		case "application/vnd.ollama.image.model":
-			model.ModelPath = filename
-			model.ParentModel = layer.From
+			m.ModelPath = filename
+			m.ParentModel = layer.From
 		case "application/vnd.ollama.image.embed":
 			// Deprecated in versions > 0.1.2
 			// TODO: remove this warning in a future version
 			slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
 		case "application/vnd.ollama.image.adapter":
-			model.AdapterPaths = append(model.AdapterPaths, filename)
+			m.AdapterPaths = append(m.AdapterPaths, filename)
 		case "application/vnd.ollama.image.projector":
-			model.ProjectorPaths = append(model.ProjectorPaths, filename)
+			m.ProjectorPaths = append(m.ProjectorPaths, filename)
 		case "application/vnd.ollama.image.prompt",
 			"application/vnd.ollama.image.template":
 			bts, err := os.ReadFile(filename)
@@ -309,7 +311,7 @@ func GetModel(name string) (*Model, error) {
 				return nil, err
 			}

-			model.Template, err = template.Parse(string(bts))
+			m.Template, err = template.Parse(string(bts))
 			if err != nil {
 				return nil, err
 			}
@@ -319,7 +321,7 @@ func GetModel(name string) (*Model, error) {
 				return nil, err
 			}

-			model.System = string(bts)
+			m.System = string(bts)
 		case "application/vnd.ollama.image.params":
 			params, err := os.Open(filename)
 			if err != nil {
@@ -328,7 +330,7 @@ func GetModel(name string) (*Model, error) {
 			defer params.Close()

 			// parse model options parameters into a map so that we can see which fields have been specified explicitly
-			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
+			if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
 				return nil, err
 			}
 		case "application/vnd.ollama.image.messages":
@@ -338,7 +340,7 @@ func GetModel(name string) (*Model, error) {
 			}
 			defer msgs.Close()

-			if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
+			if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
 				return nil, err
 			}
 		case "application/vnd.ollama.image.license":
@@ -346,11 +348,50 @@ func GetModel(name string) (*Model, error) {
 			if err != nil {
 				return nil, err
 			}
-			model.License = append(model.License, string(bts))
+			m.License = append(m.License, string(bts))
 		}
 	}

-	return model, nil
+	capabilities := m.Capabilities()
+	if slices.Contains(capabilities, model.CapabilityTools) {
+		m.addToolPrefix()
+	}
+
+	return m, nil
+}
+
+// HasToolPrefix checks if the completion starts with the tool prefix, ignoring whitespace
+func (m *Model) HasToolPrefix(sb strings.Builder) bool {
+	text := strings.ReplaceAll(strings.TrimSpace(sb.String()), " ", "")
+	toolString := strings.ReplaceAll(strings.TrimSpace(m.ToolPrefix), " ", "")
+	if len(text) < len(toolString) {
+		return text == toolString[:len(text)]
+	}
+	return text[:len(toolString)] == toolString
+}
+
+// Figure out what's between the start of the tools block, and the json response, and use it as a marker. Usually that's
+// {{- if .ToolCalls}}this text{{ range .ToolCalls}}or maybe this text{{.name}}
+func (m *Model) addToolPrefix() {
+	// create a subtree from the node that ranges over .ToolCalls
+	var previousNode parse.Node
+	toolCallsTemplate := m.Template.Subtree(func(node parse.Node) bool {
+		if rangeNode, ok := node.(*parse.RangeNode); ok {
+			return slices.Contains(template.Identifiers(rangeNode.Pipe), "ToolCalls")
+		}
+		previousNode = node
+		return false
+	})
+
+	if textNode, ok := previousNode.(*parse.TextNode); ok {
+		m.ToolPrefix = strings.TrimSpace(textNode.String())
+	}
+
+	if len(m.ToolPrefix) == 0 && len(toolCallsTemplate.Root.Nodes) > 0 {
+		rangeNode, ok := toolCallsTemplate.Root.Nodes[0].(*parse.RangeNode)
+		if ok && len(rangeNode.List.Nodes) > 0 {
+			m.ToolPrefix = rangeNode.List.Nodes[0].String()
+		}
+	}
 }

 func CopyModel(src, dst model.Name) error {
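The whitespace-insensitive comparison in `HasToolPrefix` is what lets the server hold a stream back: a partial accumulation such as `[TOOL` still counts as a possible match for a stored prefix like `[TOOL_CALLS] [`. A rough standalone sketch of the same comparison (taking a plain string where the method takes a `strings.Builder`):

```go
package main

import (
	"fmt"
	"strings"
)

// hasToolPrefix mirrors Model.HasToolPrefix above: strip spaces from
// both sides, then compare the shorter string against the same-length
// prefix of the longer one.
func hasToolPrefix(accumulated, toolPrefix string) bool {
	text := strings.ReplaceAll(strings.TrimSpace(accumulated), " ", "")
	tool := strings.ReplaceAll(strings.TrimSpace(toolPrefix), " ", "")
	if len(text) < len(tool) {
		return text == tool[:len(text)]
	}
	return text[:len(tool)] == tool
}

func main() {
	const prefix = "[TOOL_CALLS] ["
	fmt.Println(hasToolPrefix("[TOOL", prefix))           // true: could still become a tool call
	fmt.Println(hasToolPrefix("[TOOL_CALLS] [{", prefix)) // true: prefix fully matched
	fmt.Println(hasToolPrefix("The weather is", prefix))  // false: safe to stream as content
}
```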

View File

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
+	"strings"
 	"testing"

 	"github.com/google/go-cmp/cmp"
@@ -31,16 +32,17 @@ func TestExecuteWithTools(t *testing.T) {
 		model      string
 		output     string
 		ok         bool
+		wellFormed bool
 	}{
-		{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
+		{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true},
 		{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
-The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
+The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true, false},
-		{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
+		{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false, false},
 		{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
-[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
+[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, false},
-		{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+		{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
 		{"command-r-plus", "Action: ```json" + `
 [
     {
@@ -58,16 +60,17 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
         }
     }
 ]
-` + "```", true},
-		{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
-		{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
-		{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+` + "```", true, true},
+		{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
+		{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true},
+		{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
 		{"llama3-groq-tool-use", `<tool_call>
 {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
 {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
-</tool_call>`, true},
-		{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
-		{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
+</tool_call>`, true, true},
+		{"xlam", `### Response:
+{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true, true},
+		{"nemotron", `<toolcall> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true, true},
 	}

 	var tools []api.Tool
@@ -119,6 +122,21 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
 			}
 		})

+		t.Run("prefix", func(t *testing.T) {
+			m := &Model{Template: tmpl}
+			m.addToolPrefix()
+
+			if tt.wellFormed {
+				if len(m.ToolPrefix) == 0 {
+					t.Fatalf("No tool prefix detected")
+				}
+
+				if !strings.HasPrefix(strings.TrimSpace(tt.output), m.ToolPrefix) {
+					t.Fatalf("incorrect tool prefix: \"%s\", \"%s\"", m.ToolPrefix, tt.output)
+				}
+			}
+		})
+
 		t.Run("parse", func(t *testing.T) {
 			m := &Model{Template: tmpl}
 			actual, ok := m.parseToolCalls(tt.output)
@@ -177,3 +195,64 @@ func TestParseObjects(t *testing.T) {
 		})
 	}
 }
+
+func TestAddToolPrefix(t *testing.T) {
+	tests := []struct {
+		name     string
+		template string
+		want     string
+	}{
+		{
+			name:     "prefix_from_previous_text_node",
+			template: `Previous text node{{- range .ToolCalls}}{{.name}}{{end}}`,
+			want:     "Previous text node",
+		},
+		{
+			name:     "prefix_from_range_node",
+			template: `{{- range .ToolCalls}}[TOOL_CALLS]{{.name}}{{end}}`,
+			want:     "[TOOL_CALLS]",
+		},
+		{
+			name:     "prefix_with_extra_whitespace",
+			template: ` Previous text with spaces {{- range .ToolCalls}}{{.name}}{{end}}`,
+			want:     "Previous text with spaces",
+		},
+		{
+			name:     "prefix_with_newlines",
+			template: "First line\nSecond line\n{{- range .ToolCalls}}{{.name}}{{end}}",
+			want:     "First line\nSecond line",
+		},
+		{
+			name: "tool_calls_json_template",
+			template: `{{ if .Content }}{{ .Content }}{{- else if .ToolCalls }}<tool_call>
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}</tool_call>
+{{ end }}`,
+			want: `<tool_call>`,
+		},
+		{
+			name: "mistral_tool_calls_template",
+			template: `{{- if .Content }} {{ .Content }}
+{{- else if .ToolCalls }}[TOOL_CALLS] [
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{- end }}]
+{{- end }}</s>`,
+			want: "[TOOL_CALLS] [",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tmpl, err := template.Parse(tt.template)
+			if err != nil {
+				t.Fatalf("failed to parse template: %v", err)
+			}
+
+			m := &Model{Template: tmpl}
+			m.addToolPrefix()
+
+			if m.ToolPrefix != tt.want {
+				t.Errorf("incorrect tool prefix:\ngot:  %q\nwant: %q", m.ToolPrefix, tt.want)
+			}
+		})
+	}
+}

View File

@@ -1526,6 +1526,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			defer close(ch)

 			var sb strings.Builder
 			var toolCallIndex int = 0
+			var mightBeTools bool = true
+			buf := make([]api.ChatResponse, 0)
 			if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 				Prompt:  prompt,
 				Images:  images,
@@ -1551,18 +1553,29 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 				}

-				// TODO: tool call checking and filtering should be moved outside of this callback once streaming
-				// however this was a simple change for now without reworking streaming logic of this (and other)
-				// handlers
-				if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
+				// If we know we're not streaming
+				if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 || !mightBeTools {
 					ch <- res
 					return
 				}

+				sb.WriteString(r.Content)
+
+				// Buffer up responses while we're unsure whether to stream.
+				buf = append(buf, res)
+
+				// not a tools response, continue streaming.
+				if !m.HasToolPrefix(sb) {
+					mightBeTools = false
+					for _, item := range buf {
+						ch <- item
+					}
+					return
+				}
+
 				// Streaming tool calls:
 				// If tools are recognized, use a flag to track the sending of a tool downstream
 				// This ensures that content is cleared from the message on the last chunk sent
-				sb.WriteString(r.Content)
 				if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
 					res.Message.ToolCalls = toolCalls
 					for i := range toolCalls {
@@ -1573,8 +1586,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					sb.Reset()
 					ch <- res
 					return
+				} else {
+					if !strings.HasPrefix(sb.String(), "{") {
+						ch <- res
+						return
+					}
 				}

 				if r.Done {
 					// Send any remaining content if no tool calls were detected
 					if toolCallIndex == 0 {
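Taken together, the handler now buffers chunks until the accumulated text stops matching the tool prefix, then either flushes the buffer as ordinary streamed content or keeps accumulating for `parseToolCalls`. A simplified, hypothetical sketch of that buffer-until-classified pattern:

```go
package main

import "fmt"

// streamOrBuffer is a toy version of the ChatHandler logic above:
// hold chunks back while the accumulated text could still be a tool
// call, and flush everything the moment it no longer matches.
func streamOrBuffer(chunks []string, couldBeToolCall func(string) bool) []string {
	var sent, buf []string
	var acc string
	mightBeTools := true

	for _, c := range chunks {
		if !mightBeTools {
			sent = append(sent, c) // already classified: stream directly
			continue
		}

		acc += c
		buf = append(buf, c)

		if !couldBeToolCall(acc) {
			// Not a tool call after all: flush the held-back chunks.
			mightBeTools = false
			sent = append(sent, buf...)
			buf = nil
		}
	}

	return sent
}

func main() {
	const prefix = "[TOOL_CALLS]"
	couldBe := func(s string) bool {
		if len(s) > len(prefix) {
			s = s[:len(prefix)]
		}
		return s == prefix[:len(s)]
	}

	fmt.Println(streamOrBuffer([]string{"The ", "weather ", "is nice."}, couldBe))
}
```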

View File

@@ -372,25 +372,25 @@ func TestGenerateChat(t *testing.T) {
 					Type       string   `json:"type"`
 					Required   []string `json:"required"`
 					Properties map[string]struct {
-						Type        api.PropertyType `json:"type"`
+						Type        string   `json:"type"`
 						Description string   `json:"description"`
-						Enum        []any    `json:"enum,omitempty"`
+						Enum        []string `json:"enum,omitempty"`
 					} `json:"properties"`
 				}{
 					Type:     "object",
 					Required: []string{"location"},
 					Properties: map[string]struct {
-						Type        api.PropertyType `json:"type"`
+						Type        string   `json:"type"`
 						Description string   `json:"description"`
-						Enum        []any    `json:"enum,omitempty"`
+						Enum        []string `json:"enum,omitempty"`
 					}{
 						"location": {
-							Type:        api.PropertyType{"string"},
+							Type:        "string",
 							Description: "The city and state",
 						},
 						"unit": {
-							Type: api.PropertyType{"string"},
-							Enum: []any{"celsius", "fahrenheit"},
+							Type: "string",
+							Enum: []string{"celsius", "fahrenheit"},
 						},
 					},
 				},
@@ -469,25 +469,25 @@ func TestGenerateChat(t *testing.T) {
 					Type       string   `json:"type"`
 					Required   []string `json:"required"`
 					Properties map[string]struct {
-						Type        api.PropertyType `json:"type"`
+						Type        string   `json:"type"`
 						Description string   `json:"description"`
-						Enum        []any    `json:"enum,omitempty"`
+						Enum        []string `json:"enum,omitempty"`
 					} `json:"properties"`
 				}{
 					Type:     "object",
 					Required: []string{"location"},
 					Properties: map[string]struct {
-						Type        api.PropertyType `json:"type"`
+						Type        string   `json:"type"`
 						Description string   `json:"description"`
-						Enum        []any    `json:"enum,omitempty"`
+						Enum        []string `json:"enum,omitempty"`
 					}{
 						"location": {
-							Type:        api.PropertyType{"string"},
+							Type:        "string",
 							Description: "The city and state",
 						},
 						"unit": {
-							Type: api.PropertyType{"string"},
-							Enum: []any{"celsius", "fahrenheit"},
+							Type: "string",
+							Enum: []string{"celsius", "fahrenheit"},
 						},
 					},
 				},