Compare commits
21 Commits
v0.4.6
...
parth/open
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2536ffe0ab | ||
|
|
97abd7bfea | ||
|
|
c6509bf76e | ||
|
|
aed1419c64 | ||
|
|
c6c526275d | ||
|
|
630e7dc6ff | ||
|
|
eb8366d658 | ||
|
|
4456012956 | ||
|
|
539be43640 | ||
|
|
1bdab9fdb1 | ||
|
|
2b82c5a8a1 | ||
|
|
55c3efa900 | ||
|
|
1aedffad93 | ||
|
|
ff6c2d6dc8 | ||
|
|
d543b282a7 | ||
|
|
5f8051180e | ||
|
|
39e29ae5dd | ||
|
|
30a9f063c9 | ||
|
|
7355ab3703 | ||
|
|
7ed81437fe | ||
|
|
220108d3f4 |
7
.github/workflows/test.yaml
vendored
7
.github/workflows/test.yaml
vendored
@@ -243,7 +243,7 @@ jobs:
|
||||
$env:PATH="$gopath;$gccpath;$env:PATH"
|
||||
echo $env:PATH
|
||||
if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" }
|
||||
make -j 4
|
||||
make -j 4
|
||||
- name: 'Build Unix Go Runners'
|
||||
if: ${{ ! startsWith(matrix.os, 'windows-') }}
|
||||
run: make -j 4
|
||||
@@ -310,8 +310,7 @@ jobs:
|
||||
arm64) echo ARCH=arm64 ;;
|
||||
esac >>$GITHUB_ENV
|
||||
shell: bash
|
||||
- run: go build
|
||||
- run: go test -v ./...
|
||||
- run: go test ./...
|
||||
|
||||
patches:
|
||||
needs: [changes]
|
||||
@@ -323,4 +322,4 @@ jobs:
|
||||
submodules: recursive
|
||||
- name: Verify patches carry all the changes
|
||||
run: |
|
||||
make apply-patches sync && git diff --compact-summary --exit-code llama
|
||||
make apply-patches sync && git diff --compact-summary --exit-code llama
|
||||
|
||||
@@ -346,6 +346,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
|
||||
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
|
||||
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
||||
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
|
||||
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
|
||||
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
|
||||
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
|
||||
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
|
||||
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
|
||||
@@ -356,6 +359,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama)
|
||||
- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux)
|
||||
- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
|
||||
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
|
||||
|
||||
### Cloud
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ type GenerateRequest struct {
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
|
||||
// Format specifies the format to return a response in.
|
||||
Format string `json:"format"`
|
||||
Format json.RawMessage `json:"format,omitempty"`
|
||||
|
||||
// KeepAlive controls how long the model will stay loaded in memory following
|
||||
// this request.
|
||||
@@ -94,7 +94,7 @@ type ChatRequest struct {
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
|
||||
// Format is the format to return the response in (e.g. "json").
|
||||
Format string `json:"format"`
|
||||
Format json.RawMessage `json:"format,omitempty"`
|
||||
|
||||
// KeepAlive controls how long the model will stay loaded into memory
|
||||
// following the request.
|
||||
@@ -146,6 +146,7 @@ type ToolCall struct {
|
||||
}
|
||||
|
||||
type ToolCallFunction struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -1038,7 +1039,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
Format: opts.Format,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
}
|
||||
|
||||
@@ -1125,7 +1126,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
Prompt: opts.Prompt,
|
||||
Context: generateContext,
|
||||
Images: opts.Images,
|
||||
Format: opts.Format,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
System: opts.System,
|
||||
Options: opts.Options,
|
||||
KeepAlive: opts.KeepAlive,
|
||||
@@ -1445,6 +1446,7 @@ func NewCLI() *cobra.Command {
|
||||
envVars["OLLAMA_SCHED_SPREAD"],
|
||||
envVars["OLLAMA_TMPDIR"],
|
||||
envVars["OLLAMA_FLASH_ATTENTION"],
|
||||
envVars["OLLAMA_KV_CACHE_TYPE"],
|
||||
envVars["OLLAMA_LLM_LIBRARY"],
|
||||
envVars["OLLAMA_GPU_OVERHEAD"],
|
||||
envVars["OLLAMA_LOAD_TIMEOUT"],
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -180,18 +179,14 @@ Weigh anchor!
|
||||
|
||||
t.Run("license", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
license, err := os.ReadFile(filepath.Join("..", "LICENSE"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
license := "MIT License\nCopyright (c) Ollama\n"
|
||||
if err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "test",
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
License: string(license),
|
||||
License: license,
|
||||
}, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
@@ -60,7 +61,25 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||
addedTokens[t.Content] = t
|
||||
}
|
||||
|
||||
t.Merges = tt.Model.Merges
|
||||
if len(tt.Model.Merges) == 0 {
|
||||
// noop; merges is empty
|
||||
} else if err := json.Unmarshal(tt.Model.Merges, &t.Merges); err == nil {
|
||||
// noop; merges is []string
|
||||
} else if merges, err := func() ([][]string, error) {
|
||||
var merges [][]string
|
||||
if err := json.Unmarshal(tt.Model.Merges, &merges); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return merges, nil
|
||||
}(); err == nil {
|
||||
t.Merges = make([]string, len(merges))
|
||||
for i := range merges {
|
||||
t.Merges[i] = strings.Join(merges[i], " ")
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("could not parse tokenizer merges. expected []string or [][]string: %w", err)
|
||||
}
|
||||
|
||||
sha256sum := sha256.New()
|
||||
for _, pt := range tt.PreTokenizer.PreTokenizers {
|
||||
@@ -156,9 +175,9 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
||||
type tokenizer struct {
|
||||
AddedTokens []token `json:"added_tokens"`
|
||||
Model struct {
|
||||
Type string `json:"type"`
|
||||
Vocab map[string]int `json:"vocab"`
|
||||
Merges []string `json:"merges"`
|
||||
Type string `json:"type"`
|
||||
Vocab map[string]int `json:"vocab"`
|
||||
Merges json.RawMessage `json:"merges"`
|
||||
} `json:"model"`
|
||||
|
||||
PreTokenizer struct {
|
||||
|
||||
@@ -191,6 +191,62 @@ func TestParseTokenizer(t *testing.T) {
|
||||
Pre: "default",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list string merges",
|
||||
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||
"tokenizer.json": strings.NewReader(`{
|
||||
"model": {
|
||||
"merges": [
|
||||
"a b",
|
||||
"c d",
|
||||
"e f"
|
||||
]
|
||||
}
|
||||
}`),
|
||||
}),
|
||||
want: &Tokenizer{
|
||||
Vocabulary: &Vocabulary{
|
||||
Model: "gpt2",
|
||||
},
|
||||
Merges: []string{
|
||||
"a b",
|
||||
"c d",
|
||||
"e f",
|
||||
},
|
||||
Pre: "default",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list list string merges",
|
||||
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||
"tokenizer.json": strings.NewReader(`{
|
||||
"model": {
|
||||
"merges": [
|
||||
[
|
||||
"a", "b"
|
||||
],
|
||||
[
|
||||
"c", "d"
|
||||
],
|
||||
[
|
||||
"e", "f"
|
||||
]
|
||||
]
|
||||
}
|
||||
}`),
|
||||
}),
|
||||
want: &Tokenizer{
|
||||
Vocabulary: &Vocabulary{
|
||||
Model: "gpt2",
|
||||
},
|
||||
Merges: []string{
|
||||
"a b",
|
||||
"c d",
|
||||
"e f",
|
||||
},
|
||||
Pre: "default",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
|
||||
@@ -183,3 +183,17 @@ func (si SystemInfo) GetOptimalThreadCount() int {
|
||||
|
||||
return coreCount
|
||||
}
|
||||
|
||||
// For each GPU, check if it does NOT support flash attention
|
||||
func (l GpuInfoList) FlashAttentionSupported() bool {
|
||||
for _, gpu := range l {
|
||||
supportsFA := gpu.Library == "metal" ||
|
||||
(gpu.Library == "cuda" && gpu.DriverMajor >= 7) ||
|
||||
gpu.Library == "rocm"
|
||||
|
||||
if !supportsFA {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -49,10 +49,10 @@ Advanced parameters (optional):
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `system`: system message to (overrides what is defined in the `Modelfile`)
|
||||
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
|
||||
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
|
||||
|
||||
#### JSON mode
|
||||
|
||||
|
||||
28
docs/faq.md
28
docs/faq.md
@@ -151,7 +151,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
|
||||
|
||||
Ollama runs an HTTP server and can be exposed using a proxy server such as Nginx. To do so, configure the proxy to forward requests and optionally set required headers (if not exposing Ollama on the network). For example, with Nginx:
|
||||
|
||||
```
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name example.com; # Replace with your domain or IP
|
||||
@@ -285,4 +285,28 @@ Note: Windows with Radeon GPUs currently default to 1 model maximum due to limit
|
||||
|
||||
## How does Ollama load models on multiple GPUs?
|
||||
|
||||
Installing multiple GPUs of the same brand can be a great way to increase your available VRAM to load larger models. When you load a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transfering across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs.
|
||||
When loading a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transferring across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs.
|
||||
|
||||
## How can I enable Flash Attention?
|
||||
|
||||
Flash Attention is a feature of most modern models that can significantly reduce memory usage as the context size grows. To enable Flash Attention, set the `OLLAMA_FLASH_ATTENTION` environment variable to `1` when starting the Ollama server.
|
||||
|
||||
## How can I set the quantization type for the K/V cache?
|
||||
|
||||
The K/V context cache can be quantized to significantly reduce memory usage when Flash Attention is enabled.
|
||||
|
||||
To use quantized K/V cache with Ollama you can set the following environment variable:
|
||||
|
||||
- `OLLAMA_KV_CACHE_TYPE` - The quantization type for the K/V cache. Default is `f16`.
|
||||
|
||||
> Note: Currently this is a global option - meaning all models will run with the specified quantization type.
|
||||
|
||||
The currently available K/V cache quantization types are:
|
||||
|
||||
- `f16` - high precision and memory usage (default).
|
||||
- `q8_0` - 8-bit quantization, uses approximately 1/2 the memory of `f16` with a very small loss in precision, this usually has no noticeable impact on the model's quality (recommended if not using f16).
|
||||
- `q4_0` - 4-bit quantization, uses approximately 1/4 the memory of `f16` with a small-medium loss in precision that may be more noticeable at higher context sizes.
|
||||
|
||||
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
|
||||
|
||||
You may need to experiment with different quantization types to find the best balance between memory usage and quality.
|
||||
|
||||
@@ -63,7 +63,7 @@ SYSTEM You are Mario from super mario bros, acting as an assistant.
|
||||
To use this:
|
||||
|
||||
1. Save it as a file (e.g. `Modelfile`)
|
||||
2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>'`
|
||||
2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>`
|
||||
3. `ollama run choose-a-model-name`
|
||||
4. Start using the model!
|
||||
|
||||
@@ -156,7 +156,7 @@ PARAMETER <parameter> <parametervalue>
|
||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
||||
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
||||
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
|
||||
| num_predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | int | num_predict 42 |
|
||||
| num_predict | Maximum number of tokens to predict when generating text. (Default: -1, infinite generation) | int | num_predict 42 |
|
||||
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
|
||||
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
|
||||
| min_p | Alternative to the top_p, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. (Default: 0.0) | float | min_p 0.05 |
|
||||
|
||||
@@ -199,6 +199,8 @@ curl http://localhost:11434/v1/embeddings \
|
||||
- [x] `seed`
|
||||
- [x] `stop`
|
||||
- [x] `stream`
|
||||
- [x] `stream_options`
|
||||
- [x] `include_usage`
|
||||
- [x] `temperature`
|
||||
- [x] `top_p`
|
||||
- [x] `max_tokens`
|
||||
@@ -227,6 +229,8 @@ curl http://localhost:11434/v1/embeddings \
|
||||
- [x] `seed`
|
||||
- [x] `stop`
|
||||
- [x] `stream`
|
||||
- [x] `stream_options`
|
||||
- [x] `include_usage`
|
||||
- [x] `temperature`
|
||||
- [x] `top_p`
|
||||
- [x] `max_tokens`
|
||||
|
||||
@@ -153,6 +153,8 @@ var (
|
||||
Debug = Bool("OLLAMA_DEBUG")
|
||||
// FlashAttention enables the experimental flash attention feature.
|
||||
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
|
||||
// KvCacheType is the quantization type for the K/V cache.
|
||||
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||
// NoHistory disables readline history.
|
||||
NoHistory = Bool("OLLAMA_NOHISTORY")
|
||||
// NoPrune disables pruning of model blobs on startup.
|
||||
@@ -234,6 +236,7 @@ func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
|
||||
@@ -93,7 +93,7 @@ make -j
|
||||
|
||||
## Vendoring
|
||||
|
||||
Ollama currently vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/ggml) through a vendoring model. While we generally strive to contribute changes back upstream to avoid drift, we cary a small set of patches which are applied to the tracking commit. A set of make targets are available to aid developers in updating to a newer tracking commit, or to work on changes.
|
||||
Ollama currently vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/ggml) through a vendoring model. While we generally strive to contribute changes back upstream to avoid drift, we cary a small set of patches which are applied to the tracking commit. A set of make targets are available to aid developers in updating to a newer tracking commit, or to work on changes.
|
||||
|
||||
If you update the vendoring code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory.
|
||||
|
||||
@@ -105,35 +105,35 @@ make apply-patches
|
||||
|
||||
**Pin to new base commit**
|
||||
|
||||
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env`
|
||||
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`
|
||||
|
||||
#### Applying patches
|
||||
|
||||
When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.
|
||||
|
||||
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
|
||||
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
|
||||
|
||||
```
|
||||
make apply-patches
|
||||
```
|
||||
|
||||
If you see an error message about a conflict, go into the `./vendor/` directory, and perform merge resolution using your preferred tool to the patch commit which failed. Save the file(s) and continue the patch series with `git am --continue` . If any additional patches fail, follow the same pattern until the full patch series is applied. Once finished, run a final `create-patches` and `sync` target to ensure everything is updated.
|
||||
If you see an error message about a conflict, go into the `./vendor/` directory, and perform merge resolution using your preferred tool to the patch commit which failed. Save the file(s) and continue the patch series with `git am --continue` . If any additional patches fail, follow the same pattern until the full patch series is applied. Once finished, run a final `create-patches` and `sync` target to ensure everything is updated.
|
||||
|
||||
```
|
||||
make create-patches sync
|
||||
```
|
||||
|
||||
Build and test Ollama, and make any necessary changes to the Go code based on the new base commit. Submit your PR to the Ollama repo.
|
||||
Build and test Ollama, and make any necessary changes to the Go code based on the new base commit. Submit your PR to the Ollama repo.
|
||||
|
||||
### Generating Patches
|
||||
|
||||
When working on new fixes or features that impact vendored code, use the following model. First get a clean tracking repo with all current patches applied:
|
||||
When working on new fixes or features that impact vendored code, use the following model. First get a clean tracking repo with all current patches applied:
|
||||
|
||||
```
|
||||
make apply-patches
|
||||
```
|
||||
|
||||
Now edit the upstream native code in the `./vendor/` directory. You do not need to commit every change in order to build, a dirty working tree in the tracking repo is OK while developing. Simply save in your editor, and run the following to refresh the vendored code with your changes, build the backend(s) and build ollama:
|
||||
Now edit the upstream native code in the `./vendor/` directory. You do not need to commit every change in order to build, a dirty working tree in the tracking repo is OK while developing. Simply save in your editor, and run the following to refresh the vendored code with your changes, build the backend(s) and build ollama:
|
||||
|
||||
```
|
||||
make sync
|
||||
@@ -142,9 +142,9 @@ go build .
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Do **NOT** run `apply-patches` while you're iterating as that will reset the tracking repo. It will detect a dirty tree and abort, but if your tree is clean and you accidentally ran this target, use `git reflog` to recover your commit(s).
|
||||
> Do **NOT** run `apply-patches` while you're iterating as that will reset the tracking repo. It will detect a dirty tree and abort, but if your tree is clean and you accidentally ran this target, use `git reflog` to recover your commit(s).
|
||||
|
||||
Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with
|
||||
Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with
|
||||
|
||||
```
|
||||
make create-patches
|
||||
@@ -157,4 +157,4 @@ In your `./vendor/` directory, create a branch, and cherry-pick the new commit t
|
||||
|
||||
Commit the changes in the ollama repo and submit a PR to Ollama, which will include the vendored code update with your change, along with the patches.
|
||||
|
||||
After your PR upstream is merged, follow the **Updating Base Commit** instructions above, however first remove your patch before running `apply-patches` since the new base commit contains your change already.
|
||||
After your PR upstream is merged, follow the **Updating Base Commit** instructions above, however first remove your patch before running `apply-patches` since the new base commit contains your change already.
|
||||
|
||||
@@ -85,9 +85,12 @@ COMPILER inline get_compiler() {
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"runtime/cgo"
|
||||
"slices"
|
||||
@@ -140,7 +143,7 @@ type ContextParams struct {
|
||||
c C.struct_llama_context_params
|
||||
}
|
||||
|
||||
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
|
||||
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
|
||||
params := C.llama_context_default_params()
|
||||
params.n_ctx = C.uint(numCtx)
|
||||
params.n_batch = C.uint(batchSize)
|
||||
@@ -149,9 +152,28 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
|
||||
params.n_threads_batch = params.n_threads
|
||||
params.embeddings = C.bool(true)
|
||||
params.flash_attn = C.bool(flashAttention)
|
||||
params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
|
||||
params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
|
||||
|
||||
return ContextParams{c: params}
|
||||
}
|
||||
|
||||
// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
|
||||
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
|
||||
if s == "" {
|
||||
return C.GGML_TYPE_F16
|
||||
}
|
||||
|
||||
switch s {
|
||||
case "q8_0":
|
||||
return C.GGML_TYPE_Q8_0
|
||||
case "q4_0":
|
||||
return C.GGML_TYPE_Q4_0
|
||||
default:
|
||||
return C.GGML_TYPE_F16
|
||||
}
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
c *C.struct_llama_context
|
||||
numThreads int
|
||||
@@ -680,3 +702,33 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
|
||||
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
|
||||
C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
|
||||
}
|
||||
|
||||
type JsonSchema struct {
|
||||
Defs map[string]any `json:"$defs,omitempty"`
|
||||
Properties map[string]any `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
func (js JsonSchema) AsGrammar() string {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(js); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
cStr := C.CString(b.String())
|
||||
defer C.free(unsafe.Pointer(cStr))
|
||||
|
||||
// Allocate buffer for grammar output with reasonable size
|
||||
const maxLen = 32768 // 32KB
|
||||
buf := make([]byte, maxLen)
|
||||
|
||||
// Call C function to convert schema to grammar
|
||||
length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
|
||||
if length == 0 {
|
||||
slog.Warn("unable to convert schema to grammar")
|
||||
}
|
||||
|
||||
return string(buf[:length])
|
||||
}
|
||||
|
||||
@@ -1 +1,70 @@
|
||||
package llama
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestJsonSchema(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
schema JsonSchema
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty schema",
|
||||
schema: JsonSchema{
|
||||
Type: "object",
|
||||
},
|
||||
expected: `array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= | " " | "\n" [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null`,
|
||||
},
|
||||
{
|
||||
name: "invalid schema with circular reference",
|
||||
schema: JsonSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"self": map[string]any{
|
||||
"$ref": "#", // Self reference
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "", // Should return empty string for invalid schema
|
||||
},
|
||||
{
|
||||
name: "schema with invalid type",
|
||||
schema: JsonSchema{
|
||||
Type: "invalid_type", // Invalid type
|
||||
Properties: map[string]any{
|
||||
"foo": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "", // Should return empty string for invalid schema
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := tc.schema.AsGrammar()
|
||||
if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
|
||||
if diff := cmp.Diff(tc.expected, result); diff != "" {
|
||||
t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -850,6 +850,7 @@ func (s *Server) loadModel(
|
||||
lpath multiLPath,
|
||||
ppath string,
|
||||
kvSize int,
|
||||
kvCacheType string,
|
||||
flashAttention bool,
|
||||
threads int,
|
||||
multiUserCache bool,
|
||||
@@ -862,7 +863,7 @@ func (s *Server) loadModel(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
|
||||
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
|
||||
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -903,6 +904,7 @@ func main() {
|
||||
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
|
||||
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
|
||||
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||
kvCacheType := flag.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||
port := flag.Int("port", 8080, "Port to expose the server on")
|
||||
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
||||
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
|
||||
@@ -970,7 +972,7 @@ func main() {
|
||||
}
|
||||
|
||||
server.ready.Add(1)
|
||||
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
|
||||
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *kvCacheType, *flashAttention, *threads, *multiUserCache)
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
|
||||
29
llama/sampling_ext.cpp
vendored
29
llama/sampling_ext.cpp
vendored
@@ -1,11 +1,13 @@
|
||||
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
||||
#include "sampling.h"
|
||||
#include "sampling_ext.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
|
||||
struct gpt_sampler *gpt_sampler_cinit(
|
||||
const struct llama_model *model, struct gpt_sampler_cparams *params)
|
||||
{
|
||||
try {
|
||||
try
|
||||
{
|
||||
gpt_sampler_params sparams;
|
||||
sparams.top_k = params->top_k;
|
||||
sparams.top_p = params->top_p;
|
||||
@@ -24,7 +26,9 @@ struct gpt_sampler *gpt_sampler_cinit(
|
||||
sparams.seed = params->seed;
|
||||
sparams.grammar = params->grammar;
|
||||
return gpt_sampler_init(model, sparams);
|
||||
} catch (const std::exception & err) {
|
||||
}
|
||||
catch (const std::exception &err)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
@@ -54,3 +58,24 @@ void gpt_sampler_caccept(
|
||||
{
|
||||
gpt_sampler_accept(sampler, id, apply_grammar);
|
||||
}
|
||||
|
||||
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
|
||||
{
|
||||
try
|
||||
{
|
||||
nlohmann::json schema = nlohmann::json::parse(json_schema);
|
||||
std::string grammar_str = json_schema_to_grammar(schema);
|
||||
size_t len = grammar_str.length();
|
||||
if (len >= max_len)
|
||||
{
|
||||
len = max_len - 1;
|
||||
}
|
||||
strncpy(grammar, grammar_str.c_str(), len);
|
||||
return len;
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
strncpy(grammar, "", max_len - 1);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
2
llama/sampling_ext.h
vendored
2
llama/sampling_ext.h
vendored
@@ -47,6 +47,8 @@ extern "C"
|
||||
llama_token id,
|
||||
bool apply_grammar);
|
||||
|
||||
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
36
llm/ggml.go
36
llm/ggml.go
@@ -360,7 +360,7 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffload uint64) {
|
||||
func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||
embedding := llm.KV().EmbeddingLength()
|
||||
heads := llm.KV().HeadCount()
|
||||
headsKV := llm.KV().HeadCountKV()
|
||||
@@ -372,7 +372,8 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
|
||||
|
||||
layers := llm.Tensors().Layers()
|
||||
|
||||
kv = 2 * context * llm.KV().BlockCount() * (embeddingHeadsK + embeddingHeadsV) * headsKV
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
|
||||
switch llm.KV().Architecture() {
|
||||
case "llama":
|
||||
@@ -527,3 +528,34 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// SupportsKVCacheType checks if the requested cache type is supported
|
||||
func (ggml GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
validKVCacheTypes := []string{"f16", "q8_0", "q4_0"}
|
||||
return slices.Contains(validKVCacheTypes, cacheType)
|
||||
}
|
||||
|
||||
// SupportsFlashAttention checks if the model supports flash attention
|
||||
func (ggml GGML) SupportsFlashAttention() bool {
|
||||
_, isEmbedding := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]
|
||||
if isEmbedding {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check head counts match and are non-zero
|
||||
headCountK := ggml.KV().EmbeddingHeadCountK()
|
||||
headCountV := ggml.KV().EmbeddingHeadCountV()
|
||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
}
|
||||
|
||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||
func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
switch cacheType {
|
||||
case "q8_0":
|
||||
return 1 // 1/2 of fp16
|
||||
case "q4_0":
|
||||
return 0.5 // 1/4 of fp16
|
||||
default:
|
||||
return 2 // f16 (default)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,7 +123,23 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
|
||||
slog.Warn("model missing blk.0 layer size")
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
||||
fa := envconfig.FlashAttention() &&
|
||||
discover.GetGPUInfo().FlashAttentionSupported() &&
|
||||
ggml.SupportsFlashAttention()
|
||||
|
||||
var kvct string
|
||||
if fa {
|
||||
requested := strings.ToLower(envconfig.KvCacheType())
|
||||
if requested != "" && ggml.SupportsKVCacheType(requested) {
|
||||
kvct = requested
|
||||
}
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
|
||||
|
||||
// KV is proportional to the number of layers
|
||||
layerSize += kv / ggml.KV().BlockCount()
|
||||
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
||||
}
|
||||
@@ -131,9 +147,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
|
||||
graphFullOffload = graphPartialOffload
|
||||
}
|
||||
|
||||
// KV is proportional to the number of layers
|
||||
layerSize += kv / ggml.KV().BlockCount()
|
||||
|
||||
// on metal there's no partial offload overhead
|
||||
if gpus[0].Library == "metal" {
|
||||
graphPartialOffload = graphFullOffload
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
func TestEstimateGPULayers(t *testing.T) {
|
||||
t.Setenv("OLLAMA_DEBUG", "1")
|
||||
t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
|
||||
|
||||
modelName := "dummy"
|
||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||
|
||||
@@ -214,15 +214,36 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
|
||||
params = append(params, "--threads", strconv.Itoa(defaultThreads))
|
||||
}
|
||||
|
||||
flashAttnEnabled := envconfig.FlashAttention()
|
||||
fa := envconfig.FlashAttention()
|
||||
if fa && !gpus.FlashAttentionSupported() {
|
||||
slog.Warn("flash attention enabled but not supported by gpu")
|
||||
fa = false
|
||||
}
|
||||
|
||||
for _, g := range gpus {
|
||||
// only cuda (compute capability 7+) and metal support flash attention
|
||||
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
|
||||
flashAttnEnabled = false
|
||||
if fa && !ggml.SupportsFlashAttention() {
|
||||
slog.Warn("flash attention enabled but not supported by model")
|
||||
fa = false
|
||||
}
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if fa {
|
||||
slog.Info("enabling flash attention")
|
||||
params = append(params, "--flash-attn")
|
||||
|
||||
// Flash Attention also supports kv cache quantization
|
||||
// Enable if the requested and kv cache type is supported by the model
|
||||
if kvct != "" && ggml.SupportsKVCacheType(kvct) {
|
||||
params = append(params, "--kv-cache-type", kvct)
|
||||
} else {
|
||||
slog.Warn("kv cache type not supported by model", "type", kvct)
|
||||
}
|
||||
} else if kvct != "" && kvct != "f16" {
|
||||
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
|
||||
}
|
||||
|
||||
// mmap has issues with partial offloading on metal
|
||||
// mmap has issues with partial offloading on metal
|
||||
for _, g := range gpus {
|
||||
if g.Library == "metal" &&
|
||||
uint64(opts.NumGPU) > 0 &&
|
||||
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
||||
@@ -231,10 +252,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
|
||||
}
|
||||
}
|
||||
|
||||
if flashAttnEnabled {
|
||||
params = append(params, "--flash-attn")
|
||||
}
|
||||
|
||||
// Windows CUDA should not use mmap for best performance
|
||||
// Linux with a model larger than free space, mmap leads to thrashing
|
||||
// For CPU loads we want the memory to be allocated, not FS cache
|
||||
@@ -617,27 +634,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
const jsonGrammar = `
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
@@ -667,7 +679,7 @@ type completion struct {
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Format json.RawMessage
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
}
|
||||
@@ -732,10 +744,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
if req.Format == "json" {
|
||||
request["grammar"] = jsonGrammar
|
||||
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
||||
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
||||
// TODO (parthsareen): Move conversion to grammar with sampling logic
|
||||
// API should do error handling for invalid formats
|
||||
if req.Format != nil {
|
||||
if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
|
||||
request["grammar"] = jsonGrammar
|
||||
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
||||
slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
||||
}
|
||||
} else if schema, err := func() (llama.JsonSchema, error) {
|
||||
var schema llama.JsonSchema
|
||||
err := json.Unmarshal(req.Format, &schema)
|
||||
return schema, err
|
||||
}(); err == nil {
|
||||
request["grammar"] = schema.AsGrammar()
|
||||
} else {
|
||||
slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
143
openai/openai.go
143
openai/openai.go
@@ -62,7 +62,12 @@ type Usage struct {
|
||||
}
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
Type string `json:"type"`
|
||||
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type JsonSchema struct {
|
||||
Schema map[string]any `json:"schema"`
|
||||
}
|
||||
|
||||
type EmbedRequest struct {
|
||||
@@ -70,10 +75,15 @@ type EmbedRequest struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
StreamOptions *StreamOptions `json:"stream_options"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
Seed *int `json:"seed"`
|
||||
Stop any `json:"stop"`
|
||||
@@ -102,21 +112,23 @@ type ChatCompletionChunk struct {
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Choices []ChunkChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
||||
type CompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
Seed *int `json:"seed"`
|
||||
Stop any `json:"stop"`
|
||||
Stream bool `json:"stream"`
|
||||
Temperature *float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
Suffix string `json:"suffix"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||
MaxTokens *int `json:"max_tokens"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
Seed *int `json:"seed"`
|
||||
Stop any `json:"stop"`
|
||||
Stream bool `json:"stream"`
|
||||
StreamOptions *StreamOptions `json:"stream_options"`
|
||||
Temperature *float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
Suffix string `json:"suffix"`
|
||||
}
|
||||
|
||||
type Completion struct {
|
||||
@@ -136,10 +148,12 @@ type CompletionChunk struct {
|
||||
Choices []CompleteChunkChoice `json:"choices"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
@@ -191,6 +205,14 @@ func NewError(code int, message string) ErrorResponse {
|
||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||
}
|
||||
|
||||
func toUsage(r api.ChatResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
func toolCallId() string {
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, 8)
|
||||
@@ -206,6 +228,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
|
||||
toolCalls[i].ID = toolCallId()
|
||||
toolCalls[i].Type = "function"
|
||||
toolCalls[i].Function.Name = tc.Function.Name
|
||||
toolCalls[i].Index = tc.Function.Index
|
||||
|
||||
args, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
@@ -239,11 +262,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
}},
|
||||
Usage: Usage{
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
},
|
||||
Usage: toUsage(r),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,6 +287,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
||||
}
|
||||
}
|
||||
|
||||
func toUsageGenerate(r api.GenerateResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||
return Completion{
|
||||
Id: id,
|
||||
@@ -285,11 +312,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
}},
|
||||
Usage: Usage{
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
},
|
||||
Usage: toUsageGenerate(r),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,9 +503,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
options["top_p"] = 1.0
|
||||
}
|
||||
|
||||
var format string
|
||||
if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
|
||||
format = "json"
|
||||
var format json.RawMessage
|
||||
if r.ResponseFormat != nil {
|
||||
switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
|
||||
// Support the old "json_object" type for OpenAI compatibility
|
||||
case "json_object":
|
||||
format = json.RawMessage(`"json"`)
|
||||
case "json_schema":
|
||||
if r.ResponseFormat.JsonSchema != nil {
|
||||
schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal json schema: %w", err)
|
||||
}
|
||||
format = schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &api.ChatRequest{
|
||||
@@ -551,14 +586,16 @@ type BaseWriter struct {
|
||||
}
|
||||
|
||||
type ChatWriter struct {
|
||||
stream bool
|
||||
id string
|
||||
stream bool
|
||||
streamOptions *StreamOptions
|
||||
id string
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type CompleteWriter struct {
|
||||
stream bool
|
||||
id string
|
||||
stream bool
|
||||
streamOptions *StreamOptions
|
||||
id string
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
@@ -601,7 +638,11 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
// chat chunk
|
||||
if w.stream {
|
||||
d, err := json.Marshal(toChunk(w.id, chatResponse))
|
||||
c := toChunk(w.id, chatResponse)
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
c.Usage = &Usage{}
|
||||
}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -613,6 +654,17 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
if chatResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := toUsage(chatResponse)
|
||||
d, err := json.Marshal(ChatCompletionChunk{Choices: []ChunkChoice{}, Usage: &u})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -650,7 +702,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
|
||||
// completion chunk
|
||||
if w.stream {
|
||||
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
|
||||
c := toCompleteChunk(w.id, generateResponse)
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
c.Usage = &Usage{}
|
||||
}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -662,6 +718,17 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
if generateResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := toUsageGenerate(generateResponse)
|
||||
d, err := json.Marshal(CompletionChunk{Choices: []CompleteChunkChoice{}, Usage: &u})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -824,9 +891,10 @@ func CompletionsMiddleware() gin.HandlerFunc {
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &CompleteWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
@@ -906,9 +974,10 @@ func ChatMiddleware() gin.HandlerFunc {
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ChatWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -107,7 +108,46 @@ func TestChatMiddleware(t *testing.T) {
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
},
|
||||
Format: "json",
|
||||
Format: json.RawMessage(`"json"`),
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with streaming usage",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": true,
|
||||
"stream_options": {"include_usage": true},
|
||||
"max_tokens": 999,
|
||||
"seed": 123,
|
||||
"stop": ["\n", "stop"],
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
"response_format": {"type": "json_object"}
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||
"seed": 123.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
},
|
||||
Format: json.RawMessage(`"json"`),
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
@@ -195,7 +235,86 @@ func TestChatMiddleware(t *testing.T) {
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "chat handler with streaming tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris?"}
|
||||
],
|
||||
"stream": true,
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris?",
|
||||
},
|
||||
},
|
||||
Tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler error forwarding",
|
||||
body: `{
|
||||
@@ -237,13 +356,13 @@ func TestChatMiddleware(t *testing.T) {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match: %+v", diff)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -283,6 +402,55 @@ func TestCompletionsMiddleware(t *testing.T) {
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completions handler stream",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"stream": true,
|
||||
"temperature": 0.8,
|
||||
"stop": ["\n", "stop"],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Options: map[string]any{
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
},
|
||||
Suffix: "suffix",
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completions handler stream with usage",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"stream": true,
|
||||
"stream_options": {"include_usage": true},
|
||||
"temperature": 0.8,
|
||||
"stop": ["\n", "stop"],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Options: map[string]any{
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
},
|
||||
Suffix: "suffix",
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completions handler error forwarding",
|
||||
body: `{
|
||||
|
||||
@@ -148,10 +148,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Format != "" && req.Format != "json" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
|
||||
return
|
||||
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
|
||||
if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
||||
return
|
||||
}
|
||||
@@ -251,6 +248,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
var b bytes.Buffer
|
||||
if req.Context != nil {
|
||||
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
|
||||
s, err := r.Detokenize(c.Request.Context(), req.Context)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -1469,7 +1467,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
go func() {
|
||||
defer close(ch)
|
||||
var sb strings.Builder
|
||||
var hasToolCalls bool
|
||||
var toolCallIndex int = 0
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
@@ -1509,16 +1507,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
sb.WriteString(r.Content)
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
for i := range toolCalls {
|
||||
toolCalls[i].Function.Index = toolCallIndex
|
||||
toolCallIndex++
|
||||
}
|
||||
res.Message.Content = ""
|
||||
sb.Reset()
|
||||
hasToolCalls = true
|
||||
ch <- res
|
||||
return
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
// Send any remaining content if no tool calls were detected
|
||||
if !hasToolCalls {
|
||||
if toolCallIndex == 0 {
|
||||
res.Message.Content = sb.String()
|
||||
}
|
||||
ch <- res
|
||||
|
||||
Reference in New Issue
Block a user