Compare commits

...

20 Commits

Author SHA1 Message Date
ParthSareen
b4cd1118ab checkpoint for vscode 2025-04-24 18:23:23 -07:00
ParthSareen
128c90d3ac checkpoint!!! 2025-04-24 16:57:54 -07:00
ParthSareen
f5872a097c checkpoint 2025-04-23 15:45:35 -07:00
ParthSareen
3ac5e0f102 model: update tool calling to use regex 2025-04-14 17:35:17 -07:00
Tom Sheffler
ef65174df2 types: include the 'items' and '$defs' fields to properly handle "array" types (#10091)
---------

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2025-04-09 17:45:49 -07:00
Ire Gaddr
42ecb9f138 fix(scheduler): make model unload order deterministic (#10185) 2025-04-09 16:01:02 -07:00
湛露先生
5c0331fd83 Fix dockerfile. (#9855)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
2025-04-09 13:24:56 -07:00
CYJiang
e7019c9455 fix(integration): move waitgroup Add(1) outside goroutine to avoid potential issue (#10070)
Signed-off-by: googs1025 <googs1025@gmail.com>
2025-04-08 15:17:40 -07:00
Michael Yang
d98bfe7e70 kvcache: stub out test structs 2025-04-08 15:08:29 -07:00
Parth Sareen
6747099d71 types: add any type and validation for ToolFunction enum (#10166) 2025-04-08 15:05:38 -07:00
frob
ccc8c6777b cleanup: remove OLLAMA_TMPDIR and references to temporary executables (#10182)
* cleanup: remove OLLAMA_TMPDIR
* cleanup: ollama doesn't use temporary executables anymore

---------

Co-authored-by: Richard Lyons <frob@cloudstaff.com>
2025-04-08 15:01:39 -07:00
Jesse Gross
dbb149e6f7 ollamarunner: Preallocate worst case graph at startup
Currently, the KV cache and graph are lazily allocated as needed.
The cache is fully allocated on first use of the corresponding
layer whereas the graph grows with the size of the context.

This can be an issue if another application allocates more VRAM
after we do our calculations - Ollama will crash in the middle of
inference. If we instead allocate the maximum needed memory at
startup of the runner, we will either succeed or fail at that point
rather than at some surprising time in the future.

Currently, this only generates a worst case batch for text, which
means that vision models may get a partial allocation and continue
to lazily allocate the rest.
2025-04-08 10:01:28 -07:00
Jesse Gross
a807985e59 ggml: Check for OOM and return as Go errors
If there is a CUDA OOM, we currently don't check the return value
and will eventually segfault. This checks for the problem and generates
a Go error. At the moment, this will still result in a panic but having
the error is the first step to being able to handle it more gracefully.
2025-04-08 10:01:28 -07:00
qwerty108109
8643c4d5bf readme: fix url for big-AGI in community integrations (#10173) 2025-04-07 19:42:26 -07:00
Jonathan Hecl
b0c3aba590 readme: add GGUF-to-ollama to community integrations (#10156) 2025-04-07 16:31:45 -07:00
qwerty108109
19c0c25de8 readme: rename community integration from Claude Dev to Cline (#10168) 2025-04-07 16:27:20 -07:00
Alex Rozgo
2f723ac2d6 types: allow tool function parameters with a single type or an array of types (#9434) 2025-04-07 14:27:01 -07:00
Devon Rifkin
249fbbe52f Merge pull request #10169 from ollama/drifkin/fix-contributing-formatting
CONTRIBUTING: fix code block formatting
2025-04-07 14:02:35 -07:00
Devon Rifkin
c38680b8a1 CONTRIBUTING: fix code block formatting
There were only 3 spaces instead of 4, so the example was being treated as containing HTML elements
2025-04-07 13:53:33 -07:00
Michael Yang
16fca86c4a digest files in parallel 2025-04-07 09:46:31 -07:00
25 changed files with 871 additions and 323 deletions

View File

@@ -51,7 +51,7 @@ see if the change were accepted.
The title should look like:
<package>: <short description>
<package>: <short description>
The package is the most affected Go package. If the change does not affect Go
code, then use the directory name instead. Changes to a single well-known

View File

@@ -104,8 +104,8 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm

View File

@@ -291,7 +291,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [big-AGI](https://github.com/enricoros/big-AGI)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
@@ -348,7 +348,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -440,6 +440,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
### Apple Vision Pro

View File

@@ -163,19 +163,65 @@ func (t *ToolCallFunctionArguments) String() string {
type Tool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
Function ToolFunction `json:"function"`
}
// PropertyType can be either a string or an array of strings
type PropertyType []string
// UnmarshalJSON implements the json.Unmarshaler interface
func (pt *PropertyType) UnmarshalJSON(data []byte) error {
// Try to unmarshal as a string first
var s string
if err := json.Unmarshal(data, &s); err == nil {
*pt = []string{s}
return nil
}
// If that fails, try to unmarshal as an array of strings
var a []string
if err := json.Unmarshal(data, &a); err != nil {
return err
}
*pt = a
return nil
}
// MarshalJSON implements the json.Marshaler interface
func (pt PropertyType) MarshalJSON() ([]byte, error) {
if len(pt) == 1 {
// If there's only one type, marshal as a string
return json.Marshal(pt[0])
}
// Otherwise marshal as an array
return json.Marshal([]string(pt))
}
// String returns a string representation of the PropertyType
func (pt PropertyType) String() string {
if len(pt) == 0 {
return ""
}
if len(pt) == 1 {
return pt[0]
}
return fmt.Sprintf("%v", []string(pt))
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
Type PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}
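
A brief usage sketch of the widened schema above (a standalone, illustrative program; the tool definition and field values are invented, only the api package types come from this change). A property "type" may now be given either as a single string or as an array of strings, and PropertyType normalizes both forms:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	raw := `{
		"name": "get_weather",
		"description": "Get the current weather",
		"parameters": {
			"type": "object",
			"required": ["location"],
			"properties": {
				"location": {"type": "string", "description": "The city and state"},
				"days": {"type": ["integer", "null"], "description": "Forecast length in days"}
			}
		}
	}`

	var fn api.ToolFunction
	if err := json.Unmarshal([]byte(raw), &fn); err != nil {
		panic(err)
	}

	// Both spellings end up as a PropertyType slice.
	fmt.Println(fn.Parameters.Properties["location"].Type) // string
	fmt.Println(fn.Parameters.Properties["days"].Type)     // [integer null]
}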

View File

@@ -231,3 +231,144 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
}
}
}
func TestToolFunction_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
wantErr string
}{
{
name: "valid enum with same types",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": ["a", "b", "c"]
}
}
}
}`,
wantErr: "",
},
{
name: "empty enum array",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": []
}
}
}
}`,
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var tf ToolFunction
err := json.Unmarshal([]byte(tt.input), &tf)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}
func TestPropertyType_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
expected PropertyType
}{
{
name: "string type",
input: `"string"`,
expected: PropertyType{"string"},
},
{
name: "array of types",
input: `["string", "number"]`,
expected: PropertyType{"string", "number"},
},
{
name: "array with single type",
input: `["string"]`,
expected: PropertyType{"string"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var pt PropertyType
if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(pt) != len(test.expected) {
t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
}
for i, v := range pt {
if v != test.expected[i] {
t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
}
}
})
}
}
func TestPropertyType_MarshalJSON(t *testing.T) {
tests := []struct {
name string
input PropertyType
expected string
}{
{
name: "single type",
input: PropertyType{"string"},
expected: `"string"`,
},
{
name: "multiple types",
input: PropertyType{"string", "number"},
expected: `["string","number"]`,
},
{
name: "empty type",
input: PropertyType{},
expected: `[]`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
data, err := json.Marshal(test.input)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if string(data) != test.expected {
t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
}
})
}
}

View File

@@ -1381,7 +1381,6 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_SCHED_SPREAD"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_KV_CACHE_TYPE"],
envVars["OLLAMA_LLM_LIBRARY"],

View File

@@ -26,7 +26,6 @@ When you run Ollama on **Windows**, there are a few different locations. You can
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
@@ -69,10 +68,6 @@ If you run into problems on Linux and want to install an older version, or you'd
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
```
## Linux tmp noexec
If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
## Linux docker
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.

View File

@@ -62,7 +62,6 @@ the explorer window by hitting `<Ctrl>+R` and type in:
- *upgrade.log* contains log output for upgrades
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
## Uninstall

View File

@@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
embedCtx := ctx
var genwg sync.WaitGroup
genwg.Add(1)
go func() {
genwg.Add(1)
defer genwg.Done()
slog.Info("Starting generate request")
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
@@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
counterMu := sync.Mutex{}
var embedwg sync.WaitGroup
for i := 0; i < threadCount; i++ {
embedwg.Add(1)
go func(i int) {
embedwg.Add(1)
defer embedwg.Done()
slog.Info("embed started", "id", i)
embedReq := api.EmbeddingRequest{

View File

@@ -56,8 +56,9 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs.
StartForward(ctx ml.Context, batch input.Batch) error
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)

View File

@@ -146,51 +146,60 @@ func (c *Causal) Close() {
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
c.updateSlidingWindow()
if !reserve {
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
c.curLoc = 0
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
c.curMask, err = c.buildMask(ctx)
return err

View File

@@ -5,7 +5,6 @@ import (
"slices"
"testing"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
@@ -281,7 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
if err != nil {
panic(err)
}
@@ -315,7 +314,7 @@ func TestCanResume(t *testing.T) {
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
})
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
@@ -342,7 +341,7 @@ func TestCanResume(t *testing.T) {
err = cache.StartForward(context, input.Batch{
Positions: []int32{4, 5},
Sequences: []int{0, 0},
})
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
@@ -372,14 +371,8 @@ func TestCanResume(t *testing.T) {
}
}
type testBackend struct{}
func (b *testBackend) Config() fs.Config {
panic("not implemented")
}
func (b *testBackend) Get(name string) ml.Tensor {
panic("not implemented")
type testBackend struct {
ml.Backend
}
func (b *testBackend) NewContext() ml.Context {
@@ -390,12 +383,10 @@ func (b *testBackend) NewContextSize(int) ml.Context {
return &testContext{}
}
func (b *testBackend) SystemInfo() string {
return "not implemented"
type testContext struct {
ml.Context
}
type testContext struct{}
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
total := 0
@@ -440,6 +431,8 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() error { return nil }
func (c *testContext) MaxGraphNodes() int {
return 10
}
@@ -447,6 +440,8 @@ func (c *testContext) MaxGraphNodes() int {
func (c *testContext) Close() {}
type testTensor struct {
ml.Tensor
dtype ml.DType
elementSize int
data []float32
@@ -474,10 +469,6 @@ func (t *testTensor) DType() ml.DType {
return t.dtype
}
func (t *testTensor) Bytes() []byte {
panic("not implemented")
}
func (t *testTensor) Floats() []float32 {
out := make([]float32, len(t.data))
copy(out, t.data)
@@ -502,64 +493,6 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return out
}
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
offset /= t.elementSize
@@ -582,43 +515,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
return view
}
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { panic("not implemented") }
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
copy(t2.(*testTensor).data, t.data)
return nil
}
func (t *testTensor) Duplicate(ctx ml.Context) ml.Tensor { panic("not implemented") }

View File

@@ -27,6 +27,11 @@ type EncoderCache struct {
// anything will be stored)
curPos int32
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// ** cache metadata **
// was something stored in the cache?
@@ -83,12 +88,14 @@ func (c *EncoderCache) Close() {
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// We work with the most recent image
if len(batch.Multimodal) > 0 {
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
}
c.curReserve = reserve
return nil
}
@@ -105,8 +112,10 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
}
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.encoderPos = c.curPos
c.encoderCached = true
if !c.curReserve {
c.encoderPos = c.curPos
c.encoderCached = true
}
if c.config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3)

View File

@@ -41,9 +41,9 @@ func (c *WrapperCache) Close() {
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, batch)
err := cache.StartForward(ctx, batch, reserve)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {

View File

@@ -97,6 +97,13 @@ type Context interface {
Forward(...Tensor) Context
Compute(...Tensor)
// Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a
// worst case graph to ensure all resources are available
// for future inference.
Reserve() error
MaxGraphNodes() int
Close()

View File

@@ -10,6 +10,7 @@ import "C"
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
@@ -42,8 +43,12 @@ func devices() []*C.struct_ggml_backend_device {
}
type Backend struct {
meta *fsggml.GGML
sched *C.struct_ggml_backend_sched
meta *fsggml.GGML
sched *C.struct_ggml_backend_sched
schedBackends []*C.struct_ggml_backend
schedBufts []*C.struct_ggml_backend_buffer_type
tensors map[string]*C.struct_ggml_tensor
// input is the backend used for inputs
@@ -281,6 +286,10 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
}
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
if b == nil {
return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
}
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
bbs[c] = b
}
@@ -385,8 +394,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
schedBackends = append(schedBackends, b)
schedBufts = append(schedBufts, bt)
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
if C.ggml_backend_is_cpu(b) {
// set number of threads for cpu backend
C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
@@ -405,7 +412,9 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
C.size_t(maxGraphNodes),
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
),
input: deviceBufferTypes[input.d],
schedBackends: schedBackends,
schedBufts: schedBufts,
input: deviceBufferTypes[input.d],
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
m := make(map[int]*C.struct_ggml_backend_buffer_type)
for i, layer := range layers {
@@ -530,6 +539,24 @@ func (c Context) Compute(tensors ...ml.Tensor) {
}
}
func (c Context) Reserve() error {
if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
C.ggml_backend_sched_reset(c.b.sched)
return errors.New("failed to reserve graph")
}
slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
for i := range c.b.schedBackends {
size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
"size", format.HumanBytes2(uint64(size)))
}
C.ggml_backend_sched_reset(c.b.sched)
return nil
}
func (c Context) MaxGraphNodes() int {
return c.maxGraphNodes
}
@@ -547,9 +574,9 @@ func pad(length, pad C.size_t) C.size_t {
return ((length + pad - 1) / pad) * pad
}
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
if c.buft == nil {
panic("set Input, Output, or Layer before creating tensors")
panic("set Input or Layer before creating tensors")
}
var cdtype uint32
@@ -570,7 +597,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
if len(shape) < 1 || shape[0] == 0 {
var shape C.int64_t = 0
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
} else if len(shape) > 4 {
panic("unsupported number of dimensions")
}
@@ -584,16 +611,29 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
if b == nil {
return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
}
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
return &Tensor{b: c.b, t: t}
return &Tensor{b: c.b, t: t}, nil
}
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
return c.newTensor(dtype, shape)
t, err := c.newTensor(dtype, shape)
if err != nil {
panic(err)
}
return t
}
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
t := c.newTensor(dtype, shape)
t, err := c.newTensor(dtype, shape)
if err != nil {
panic(err)
}
C.ggml_set_zero(t.(*Tensor).t)
return t
}
@@ -621,7 +661,11 @@ func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
return nil, err
}
t := c.newTensor(ml.DTypeF32, shape)
t, err := c.newTensor(ml.DTypeF32, shape)
if err != nil {
return nil, err
}
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
@@ -634,7 +678,11 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return nil, err
}
t := c.newTensor(ml.DTypeI32, shape)
t, err := c.newTensor(ml.DTypeI32, shape)
if err != nil {
return nil, err
}
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}

View File

@@ -299,7 +299,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch)
err := cache.StartForward(ctx, batch, false)
if err != nil {
return nil, err
}

View File

@@ -281,27 +281,31 @@ func TestChatMiddleware(t *testing.T) {
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Type: api.PropertyType{"string"},
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},

View File

@@ -11,10 +11,13 @@ import (
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"golang.org/x/sync/errgroup"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
@@ -144,12 +147,25 @@ func fileDigestMap(path string) (map[string]string, error) {
files = []string{path}
}
var mu sync.Mutex
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
for _, f := range files {
digest, err := digestForFile(f)
if err != nil {
return nil, err
}
fl[f] = digest
g.Go(func() error {
digest, err := digestForFile(f)
if err != nil {
return err
}
mu.Lock()
defer mu.Unlock()
fl[f] = digest
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return fl, nil

View File

@@ -448,7 +448,7 @@ func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
func (m *mockCache) Close() {}
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { return nil }
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
func (m *mockCache) SetConfig(ml.CacheConfig) {}
func (m *mockCache) CanResume(seq int, pos int32) bool { return true }

View File

@@ -728,6 +728,51 @@ func (m *multiLPath) String() string {
return strings.Join(*m, ", ")
}
func (s *Server) reserveWorstCaseGraph() error {
ctx := s.model.Backend().NewContext()
defer ctx.Close()
var batch input.Batch
inputs := make([]int32, s.batchSize)
batch.Positions = make([]int32, len(inputs))
batch.Sequences = make([]int, len(inputs))
for i := range inputs {
batch.Positions[i] = int32(i)
}
batch.Outputs = make([]int32, s.parallel)
for i := range batch.Outputs {
batch.Outputs[i] = int32(i)
}
var err error
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return err
}
cache := s.model.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, true)
if err != nil {
return err
}
}
t, err := s.model.Forward(ctx, batch)
if err != nil {
return err
}
err = ctx.Forward(t).Reserve()
if err != nil {
return err
}
return nil
}
func (s *Server) loadModel(
ctx context.Context,
mpath string,
@@ -765,6 +810,11 @@ func (s *Server) loadModel(
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
err = s.reserveWorstCaseGraph()
if err != nil {
panic(err)
}
s.status = llm.ServerStatusReady
s.ready.Done()
}

View File

@@ -10,6 +10,7 @@ import (
"log/slog"
"net/http"
"os"
"regexp"
"slices"
"strings"
"text/template/parse"
@@ -153,99 +154,342 @@ func parseObjects(s string) []map[string]any {
return objs
}
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
// Get tool call token from model template
func (m *Model) TemplateToolToken() (string, string, bool) {
// Try to detect the tool call format from the model's template
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
// fmt.Println("tool call template", tmpl)
if tmpl != nil {
// Execute template with test data to see the format
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "function_name",
Arguments: api.ToolCallFunctionArguments{
"argument1": "value1",
// "argument2": "value2",
},
},
},
},
},
}); err != nil {
return nil, false
}
templateObjects := parseObjects(b.String())
if len(templateObjects) == 0 {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range templateObjects[0] {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}
if name == "" || arguments == "" {
return nil, false
}
responseObjects := parseObjects(s)
if len(responseObjects) == 0 {
return nil, false
}
// collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}); err == nil {
// Look for special tokens in the template output
output := strings.TrimSpace(b.String())
slog.Debug("tool call template output", "output", output)
if strings.Contains(output, "<") {
// Extract the special token between < and >
start := strings.Index(output, "<")
end := strings.Index(output, ">")
if start >= 0 && end > start {
token := output[start : end+1]
return output, token, true
}
} else if strings.Contains(output, "[") {
// Check if it's a tool call token rather than JSON array
start := strings.Index(output, "[")
end := strings.Index(output, "]")
if start >= 0 && end > start {
token := output[start : end+1]
// Only consider it a token if it's not valid JSON
var jsonTest any
if err := json.Unmarshal([]byte(token), &jsonTest); err != nil {
return output, token, true
}
}
}
}
return all
}
return "", "", false
}
var objs []map[string]any
for _, p := range responseObjects {
objs = append(objs, collect(p)...)
func parsePythonFunctionCall(s string) ([]api.ToolCall, bool) {
re := regexp.MustCompile(`(\w+)\((.*?)\)`)
matches := re.FindAllStringSubmatchIndex(s, -1)
if len(matches) == 0 {
return nil, false
}
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
for _, match := range matches {
name := s[match[2]:match[3]]
args := s[match[4]:match[5]]
arguments := make(api.ToolCallFunctionArguments)
if strings.Contains(args, "=") { // Keyword args
pairs := strings.SplitSeq(args, ",")
for pair := range pairs {
pair = strings.TrimSpace(pair)
kv := strings.Split(pair, "=")
if len(kv) == 2 {
key := strings.TrimSpace(kv[0])
value := strings.TrimSpace(kv[1])
arguments[key] = value
}
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
Name: name,
Arguments: arguments,
},
})
}
}
return toolCalls, len(toolCalls) > 0
if len(toolCalls) > 0 {
return toolCalls, true
}
return nil, false
}
// ToolCallFormat represents different possible formats for tool calls
type toolCallFormat struct {
// Direct format
Name string `json:"name,omitempty"`
Arguments map[string]any `json:"arguments,omitempty"`
// Command-r-plus format
ToolName string `json:"tool_name,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
// Function format
Function *struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
} `json:"function,omitempty"`
// Xlam format
ToolCalls []toolCallFormat `json:"tool_calls,omitempty"`
}
func parseJSONToolCalls(obj map[string]any) ([]api.ToolCall, bool) {
// Helper to convert any to []any safely
toArray := func(v any) []any {
if arr, ok := v.([]any); ok {
return arr
}
return nil
}
// Convert a single format to a tool call
makeToolCall := func(f toolCallFormat) (api.ToolCall, bool) {
switch {
case f.Name != "" && f.Arguments != nil:
return api.ToolCall{
Function: api.ToolCallFunction{
Name: f.Name,
Arguments: f.Arguments,
},
}, true
case f.Name != "" && f.Parameters != nil: // Handle parameters field
return api.ToolCall{
Function: api.ToolCallFunction{
Name: f.Name,
Arguments: f.Parameters,
},
}, true
case f.ToolName != "" && f.Parameters != nil:
return api.ToolCall{
Function: api.ToolCallFunction{
Name: f.ToolName,
Arguments: f.Parameters,
},
}, true
case f.Function != nil && f.Function.Name != "":
args := f.Function.Arguments
if args == nil {
args = f.Function.Parameters
}
if args != nil {
return api.ToolCall{
Function: api.ToolCallFunction{
Name: f.Function.Name,
Arguments: args,
},
}, true
}
}
return api.ToolCall{}, false
}
// Try parsing as array first
if arr := toArray(obj); arr != nil {
var calls []api.ToolCall
for _, item := range arr {
if itemMap, ok := item.(map[string]any); ok {
var format toolCallFormat
data, _ := json.Marshal(itemMap)
if err := json.Unmarshal(data, &format); err == nil {
if call, ok := makeToolCall(format); ok {
calls = append(calls, call)
}
}
}
}
if len(calls) > 0 {
return calls, true
}
}
// Try parsing as single object
var format toolCallFormat
data, _ := json.Marshal(obj)
if err := json.Unmarshal(data, &format); err != nil {
return nil, false
}
// Handle xlam format (tool_calls array)
if len(format.ToolCalls) > 0 {
var calls []api.ToolCall
for _, f := range format.ToolCalls {
if call, ok := makeToolCall(f); ok {
calls = append(calls, call)
}
}
if len(calls) > 0 {
return calls, true
}
}
// Try as single tool call
if call, ok := makeToolCall(format); ok {
return []api.ToolCall{call}, true
}
return nil, false
}
// token, partial, success
func deriveToolToken(s string, prefix string) (string, bool, bool) {
// There shouldn't be spaces in a tool token
if len(strings.Fields(s)) > 1 {
return "", false, false
}
if prefix == "[" && len(s) > 1 && s[len(s)-1] == ']' {
return s, false, true
} else if prefix == "<" && len(s) > 1 && s[len(s)-1] == '>' {
return s, false, true
}
return "", true, true
}
func parseJSON(s string) ([]api.ToolCall, bool) {
objs := parseObjects(s)
tcs := []api.ToolCall{}
for _, obj := range objs {
toolCalls, ok := parseJSONToolCalls(obj)
if ok {
tcs = append(tcs, toolCalls...)
}
}
if len(tcs) > 0 {
return tcs, true
}
return nil, false
}
// returns tool calls, partial, success
func (m *Model) ParseToolCalls(s string, toolToken *string) ([]api.ToolCall, bool, bool) {
// [ case can either be JSON, Python or a Tool Token
s = strings.TrimSpace(s)
fmt.Printf("ParseToolCallsNew input: %q\n", s)
if len(s) == 0 {
return nil, false, false
}
if strings.HasPrefix(s, "[") {
fmt.Println("Found [ prefix")
// JSON case
// we do not consider array JSONs as tool calls
if strings.HasPrefix(s, "[{") {
fmt.Println("Found [{ prefix - attempting JSON parse")
// TODO: mark as JSON partial
if calls, ok := parseJSON(s); ok {
fmt.Printf("Successfully parsed JSON, found %d calls\n", len(calls))
return calls, false, true
}
return nil, true, true
}
// Python Case
// We just do a full python check here
fmt.Println("Attempting Python function parse")
tc, ok := parsePythonFunctionCall(s)
if ok {
fmt.Printf("Successfully parsed Python function: %+v\n", tc)
return tc, false, true
}
// Tool Token Case - this is okay if it's a real tool token and we couldn't get from template
fmt.Println("Attempting to derive tool token")
if toolToken == nil || *toolToken == "" {
toolTok, partial, ok := deriveToolToken(s, "[")
if !ok {
return nil, false, false
}
if partial {
return nil, true, true
}
*toolToken = toolTok
}
fmt.Printf("Found tool token: %q\n", *toolToken)
s = strings.TrimSpace(s[len(*toolToken):])
fmt.Printf("Recursing with remaining string: %q\n", s)
if toolCalls, partial, ok := m.ParseToolCalls(s, toolToken); ok {
return toolCalls, partial, true
}
return nil, true, true
} else if strings.HasPrefix(s, "{") || strings.HasPrefix(s, "```") {
// // TODO: temp fix
// if strings.HasPrefix(s, "```") && len(s) == 3 {
// return nil, false, false
// }
fmt.Println("Found { prefix - attempting JSON parse with ", s)
if calls, ok := parseJSON(s); ok {
fmt.Printf("Successfully parsed JSON object, found %d calls\n", len(calls))
return calls, false, true
}
fmt.Println("Failed to parse JSON in JSON case")
// TODO: possible case where it never finishes parsing - then what?
return nil, true, true
} else if strings.HasPrefix(s, "<") {
fmt.Println("Found < prefix - attempting to derive tool token")
if toolToken == nil || *toolToken == "" {
toolTok, partial, ok := deriveToolToken(s, "<")
if !ok {
return nil, false, false
}
if partial {
return nil, true, true
}
*toolToken = toolTok
fmt.Printf("Found tool token: %q\n", *toolToken)
}
fmt.Printf("Found tool token: %q\n", *toolToken)
s = strings.TrimSpace(s[len(*toolToken):])
fmt.Printf("Recursing with remaining string: %q\n", s)
if toolCalls, partial, ok := m.ParseToolCalls(s, toolToken); ok {
return toolCalls, partial, true
}
return nil, true, true
} else if strings.Contains(s, "(") || len(strings.Fields(s)) == 1 {
fmt.Println("Attempting Python function parse")
tc, ok := parsePythonFunctionCall(s)
if ok {
fmt.Printf("Successfully parsed Python function: %+v\n", tc)
return tc, false, true
}
fmt.Printf("Failed to parse Python function: %q, returning partial", s)
return nil, true, true
}
fmt.Println("No successful parse paths found")
fmt.Printf("failed string: %q\n", s)
return nil, false, false
}
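
A rough usage sketch of the Python-style parser above (a hypothetical test in the same package; the input string is invented). Keyword arguments are split on "=" and kept as raw strings, so unquoted values come through verbatim:

// Hypothetical in-package test illustrating parsePythonFunctionCall.
func TestParsePythonFunctionCall_Sketch(t *testing.T) {
	calls, ok := parsePythonFunctionCall("get_weather(location=Paris, unit=celsius)")
	if !ok || len(calls) != 1 {
		t.Fatalf("expected one tool call, got %v (ok=%v)", calls, ok)
	}
	if calls[0].Function.Name != "get_weather" {
		t.Errorf("unexpected name %q", calls[0].Function.Name)
	}
	// Values are stored as plain strings: "Paris" and "celsius".
	if calls[0].Function.Arguments["location"] != "Paris" {
		t.Errorf("unexpected arguments %v", calls[0].Function.Arguments)
	}
}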

View File

@@ -1526,6 +1526,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch)
var sb strings.Builder
var toolCallIndex int = 0
var sentWithTools int = 0
// var prefix string
// var templateToolToken string
_, templateToolToken, _ := m.TemplateToolToken()
// fmt.Println("special token", templateToolToken)
var minDuration time.Duration = math.MaxInt64
var maxDuration time.Duration
var totalDuration time.Duration
var checkCount int
const maxToolTokens = 1
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
@@ -1546,6 +1557,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if r.Done {
slog.Debug("min duration", "duration", minDuration)
slog.Debug("max duration", "duration", maxDuration)
slog.Debug("total duration", "duration", totalDuration)
slog.Debug("check count", "count", checkCount)
// slog.Debug("average duration", "duration", totalDuration/time.Duration(checkCount))
// if sb.Len() > 0 {
// res.Message.Content = sb.String()
// }
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -1563,25 +1582,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
startTime := time.Now()
// TODO: work max tool tok logic
if len(req.Tools) > 0 && sentWithTools < maxToolTokens {
toolCalls, partial, ok := m.ParseToolCalls(sb.String(), &templateToolToken)
duration := time.Since(startTime)
checkCount++
minDuration = min(minDuration, duration)
maxDuration = max(maxDuration, duration)
totalDuration += duration
slog.Debug("tool call duration", "duration", duration)
if ok {
// fmt.Println("toolCalls", toolCalls, partial, ok, duration)
if partial {
// If the tool call is partial, we need to wait for the next chunk
return
}
slog.Debug("toolCalls", "toolCalls", toolCalls, "partial", partial, "ok", ok)
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
sentWithTools = 0
// prefix = ""
templateToolToken = ""
res.Message.Content = ""
sb.Reset()
ch <- res
// TODO: revisit this
sentWithTools++
slog.Debug("fired on tool call", "toolCalls", toolCalls, "toolCallIndex", toolCallIndex)
return
}
res.Message.Content = ""
sb.Reset()
ch <- res
return
}
if r.Done {
// Send any remaining content if no tool calls were detected
if toolCallIndex == 0 {
res.Message.Content = sb.String()
}
ch <- res
}
// Send any remaining content if no tool calls were detected
// if toolCallIndex == 0 {
// fmt.Println("toolCallIndex", toolCallIndex)
sentWithTools++
res.Message.Content = sb.String()
sb.Reset()
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
@@ -1590,11 +1632,33 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse
var sb strings.Builder
var toolCalls []api.ToolCall
const MAX_TOOL_TOKENS = 1
sentWithTools := 0
var tb strings.Builder
_, templateToolToken, _ := m.TemplateToolToken()
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
sb.WriteString(t.Message.Content)
resp = t
// TODO: work max tool tok logic
if len(req.Tools) > 0 && sentWithTools < MAX_TOOL_TOKENS {
tb.WriteString(t.Message.Content)
if tcs, partial, ok := m.ParseToolCalls(tb.String(), &templateToolToken); ok {
if !partial {
// resp.Message.ToolCalls = toolCalls
toolCalls = append(toolCalls, tcs...)
resp.Message.Content = ""
tb.Reset()
}
} else {
// equivalent to no partial - send the content downstream
tb.Reset()
sentWithTools++
}
}
case gin.H:
msg, ok := t["error"].(string)
if !ok {
@@ -1610,14 +1674,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
resp.Message.Content = sb.String()
if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
if len(toolCalls) > 0 {
resp.Message.ToolCalls = toolCalls
// resp.Message.Content = ""
}
// if len(req.Tools) > 0 {
// if toolCalls, ok := m.ParseToolCalls(sb.String()); ok {
// resp.Message.ToolCalls = toolCalls
// resp.Message.Content = ""
// }
// }
c.JSON(http.StatusOK, resp)
return
}

View File

@@ -370,27 +370,31 @@ func TestGenerateChat(t *testing.T) {
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Type: api.PropertyType{"string"},
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},
@@ -467,27 +471,31 @@ func TestGenerateChat(t *testing.T) {
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Type: api.PropertyType{"string"},
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},

View File

@@ -667,13 +667,19 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
return finished
}
type ByDuration []*runnerRef
type ByDurationAndName []*runnerRef
func (a ByDuration) Len() int { return len(a) }
func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByDuration) Less(i, j int) bool {
// uint64 to turn negative time (never unload) to largest
return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
func (a ByDurationAndName) Len() int { return len(a) }
func (a ByDurationAndName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByDurationAndName) Less(i, j int) bool {
// Primary sort by session duration (uint64 to handle negatives)
d1 := uint64(a[i].sessionDuration)
d2 := uint64(a[j].sessionDuration)
if d1 != d2 {
return d1 < d2
}
// Secondary sort by model path lex order
return a[i].modelPath < a[j].modelPath
}
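
A small illustrative check of the new ordering (a hypothetical test in the same package; the durations and paths are invented): when session durations tie, the lexically smaller model path sorts first, so the unload candidate order no longer depends on map iteration order.

// Hypothetical in-package test sketching the deterministic tie-break.
func TestUnloadOrderDeterministic_Sketch(t *testing.T) {
	runners := []*runnerRef{
		{sessionDuration: 5 * time.Minute, modelPath: "models/b.gguf"},
		{sessionDuration: 5 * time.Minute, modelPath: "models/a.gguf"},
	}
	sort.Sort(ByDurationAndName(runners))
	if runners[0].modelPath != "models/a.gguf" {
		t.Fatalf("expected lexical tie-break, got %q first", runners[0].modelPath)
	}
}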
// TODO - future consideration to pick runners based on size
@@ -775,7 +781,7 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef {
// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
// e.g., if we have multiple options, will one make room for the request?
sort.Sort(ByDuration(runnerList))
sort.Sort(ByDurationAndName(runnerList))
// First try to find a runner that's already idle
for _, runner := range runnerList {