Compare commits

...

21 Commits

Author SHA1 Message Date
ParthSareen
2536ffe0ab More cleanup 2024-12-11 18:11:00 -08:00
ParthSareen
97abd7bfea Code cleanup 2024-12-11 18:04:16 -08:00
Anuraag Agrawal
c6509bf76e Merge branch 'main' of https://github.com/ollama/ollama into openai-stream-usage 2024-12-06 12:05:25 +09:00
Jeffrey Morgan
aed1419c64 ci: skip go build for tests (#7899) 2024-12-04 21:22:36 -08:00
Parth Sareen
c6c526275d api: add generate endpoint for structured outputs (#7939) 2024-12-04 17:37:12 -08:00
Parth Sareen
630e7dc6ff api: structured outputs - chat endpoint (#7900)
Adds structured outputs to chat endpoint
---------

Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Hieu Nguyen <hieunguyen1053@outlook.com>
2024-12-04 16:31:19 -08:00
Michael Yang
eb8366d658 Merge pull request #7932 from ollama/mxyng/fix-merges 2024-12-04 10:04:52 -08:00
Michael Yang
4456012956 fix unmarshaling merges 2024-12-04 09:21:56 -08:00
Sam
539be43640 llm: normalise kvct parameter handling (#7926) 2024-12-03 16:30:40 -08:00
Sam
1bdab9fdb1 llm: introduce k/v context quantization (vRAM improvements) (#6279) 2024-12-03 15:57:19 -08:00
owboson
2b82c5a8a1 docs: correct default num_predict value in modelfile.md (#7693) 2024-12-03 15:00:05 -08:00
Tigran
55c3efa900 docs: remove extra quote in modelfile.md (#7908) 2024-12-02 09:28:56 -08:00
David Mayboroda
1aedffad93 readme: add minima to community integrations (#7906) 2024-12-02 01:14:47 -08:00
Jeffrey Morgan
ff6c2d6dc8 cmd: don't rely on reading repo file for test (#7898) 2024-11-30 14:12:53 -08:00
Jeffrey Morgan
d543b282a7 server: add warning message for deprecated context field (#7878) 2024-11-30 14:05:50 -08:00
Parth Sareen
5f8051180e Enable index tracking for tools - openai api support (#7888) 2024-11-29 20:00:09 -08:00
Jeffrey Morgan
39e29ae5dd llama: fix typo and formatting in readme (#7876) 2024-11-28 17:27:11 -08:00
TheCookingSenpai
30a9f063c9 readme: add SpaceLlama, YouLama, and DualMind to community integrations (#7216) 2024-11-28 15:16:27 -08:00
Anuraag Agrawal
7355ab3703 Return empty choices on usage chunk 2024-10-03 13:02:50 +09:00
Anuraag Agrawal
7ed81437fe Document stream_options 2024-09-17 15:25:31 +09:00
Anuraag Agrawal
220108d3f4 openai: support include_usage stream option to return final usage chunk 2024-09-13 12:32:05 +09:00
26 changed files with 695 additions and 116 deletions

View File

@@ -310,8 +310,7 @@ jobs:
arm64) echo ARCH=arm64 ;;
esac >>$GITHUB_ENV
shell: bash
- run: go build
- run: go test -v ./...
- run: go test ./...
patches:
needs: [changes]

View File

@@ -346,6 +346,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
@@ -356,6 +359,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama)
- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux)
- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
### Cloud

View File

@@ -67,7 +67,7 @@ type GenerateRequest struct {
Raw bool `json:"raw,omitempty"`
// Format specifies the format to return a response in.
Format string `json:"format"`
Format json.RawMessage `json:"format,omitempty"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
@@ -94,7 +94,7 @@ type ChatRequest struct {
Stream *bool `json:"stream,omitempty"`
// Format is the format to return the response in (e.g. "json").
Format string `json:"format"`
Format json.RawMessage `json:"format,omitempty"`
// KeepAlive controls how long the model will stay loaded into memory
// following the request.
@@ -146,6 +146,7 @@ type ToolCall struct {
}
type ToolCallFunction struct {
Index int `json:"index,omitempty"`
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}
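The `Format` change above is what enables structured outputs: the field now carries raw JSON, either the literal string `"json"` or a JSON schema object. As a rough illustration (not part of this diff), a client built against the repository's `api` package could pass a schema like this; the model name is a placeholder:

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Format now accepts raw JSON: either `"json"` or a JSON schema object.
	schema := json.RawMessage(`{
		"type": "object",
		"properties": {"capital": {"type": "string"}},
		"required": ["capital"]
	}`)

	req := &api.ChatRequest{
		Model:  "llama3.2", // placeholder model name
		Format: schema,
		Messages: []api.Message{
			{Role: "user", Content: "What is the capital of France? Respond as JSON."},
		},
	}

	// Stream the response and print each chunk as it arrives.
	err = client.Chat(context.Background(), req, func(r api.ChatResponse) error {
		fmt.Print(r.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```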

View File

@@ -8,6 +8,7 @@ import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
@@ -1038,7 +1039,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req := &api.ChatRequest{
Model: opts.Model,
Messages: opts.Messages,
Format: opts.Format,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
}
@@ -1125,7 +1126,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
Prompt: opts.Prompt,
Context: generateContext,
Images: opts.Images,
Format: opts.Format,
Format: json.RawMessage(opts.Format),
System: opts.System,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
@@ -1445,6 +1446,7 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_SCHED_SPREAD"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_KV_CACHE_TYPE"],
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_GPU_OVERHEAD"],
envVars["OLLAMA_LOAD_TIMEOUT"],

View File

@@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath"
"strings" "strings"
"testing" "testing"
@@ -180,18 +179,14 @@ Weigh anchor!
t.Run("license", func(t *testing.T) {
var b bytes.Buffer
license, err := os.ReadFile(filepath.Join("..", "LICENSE"))
if err != nil {
t.Fatal(err)
}
license := "MIT License\nCopyright (c) Ollama\n"
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
License: string(license),
License: license,
}, &b); err != nil {
t.Fatal(err)
}

View File

@@ -10,6 +10,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"slices" "slices"
"strings"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
@@ -60,7 +61,25 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
addedTokens[t.Content] = t
}
t.Merges = tt.Model.Merges
if len(tt.Model.Merges) == 0 {
// noop; merges is empty
} else if err := json.Unmarshal(tt.Model.Merges, &t.Merges); err == nil {
// noop; merges is []string
} else if merges, err := func() ([][]string, error) {
var merges [][]string
if err := json.Unmarshal(tt.Model.Merges, &merges); err != nil {
return nil, err
}
return merges, nil
}(); err == nil {
t.Merges = make([]string, len(merges))
for i := range merges {
t.Merges[i] = strings.Join(merges[i], " ")
}
} else {
return nil, fmt.Errorf("could not parse tokenizer merges. expected []string or [][]string: %w", err)
}
sha256sum := sha256.New()
for _, pt := range tt.PreTokenizer.PreTokenizers {
@@ -158,7 +177,7 @@ type tokenizer struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int `json:"vocab"`
Merges []string `json:"merges"`
Merges json.RawMessage `json:"merges"`
} `json:"model"`
PreTokenizer struct {

View File

@@ -191,6 +191,62 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default",
},
},
{
name: "list string merges",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"model": {
"merges": [
"a b",
"c d",
"e f"
]
}
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
},
Merges: []string{
"a b",
"c d",
"e f",
},
Pre: "default",
},
},
{
name: "list list string merges",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"model": {
"merges": [
[
"a", "b"
],
[
"c", "d"
],
[
"e", "f"
]
]
}
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
},
Merges: []string{
"a b",
"c d",
"e f",
},
Pre: "default",
},
},
}
for _, tt := range cases {

View File

@@ -183,3 +183,17 @@ func (si SystemInfo) GetOptimalThreadCount() int {
return coreCount
}
// For each GPU, check if it does NOT support flash attention
func (l GpuInfoList) FlashAttentionSupported() bool {
for _, gpu := range l {
supportsFA := gpu.Library == "metal" ||
(gpu.Library == "cuda" && gpu.DriverMajor >= 7) ||
gpu.Library == "rocm"
if !supportsFA {
return false
}
}
return true
}

View File

@@ -49,10 +49,10 @@ Advanced parameters (optional):
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system message to (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
#### JSON mode

View File

@@ -151,7 +151,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
Ollama runs an HTTP server and can be exposed using a proxy server such as Nginx. To do so, configure the proxy to forward requests and optionally set required headers (if not exposing Ollama on the network). For example, with Nginx:
```
```nginx
server {
listen 80;
server_name example.com; # Replace with your domain or IP
@@ -285,4 +285,28 @@ Note: Windows with Radeon GPUs currently default to 1 model maximum due to limit
## How does Ollama load models on multiple GPUs?
Installing multiple GPUs of the same brand can be a great way to increase your available VRAM to load larger models. When you load a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transfering across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs.
When loading a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transferring across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs.
## How can I enable Flash Attention?
Flash Attention is a feature of most modern models that can significantly reduce memory usage as the context size grows. To enable Flash Attention, set the `OLLAMA_FLASH_ATTENTION` environment variable to `1` when starting the Ollama server.
## How can I set the quantization type for the K/V cache?
The K/V context cache can be quantized to significantly reduce memory usage when Flash Attention is enabled.
To use quantized K/V cache with Ollama you can set the following environment variable:
- `OLLAMA_KV_CACHE_TYPE` - The quantization type for the K/V cache. Default is `f16`.
> Note: Currently this is a global option - meaning all models will run with the specified quantization type.
The currently available K/V cache quantization types are:
- `f16` - high precision and memory usage (default).
- `q8_0` - 8-bit quantization, uses approximately 1/2 the memory of `f16` with a very small loss in precision, this usually has no noticeable impact on the model's quality (recommended if not using f16).
- `q4_0` - 4-bit quantization, uses approximately 1/4 the memory of `f16` with a small-medium loss in precision that may be more noticeable at higher context sizes.
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
You may need to experiment with different quantization types to find the best balance between memory usage and quality.
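As a rough sketch of the behaviour described above (simplified; the real checks live in `envconfig` and `llm/server.go` later in this diff), the effective cache type only deviates from `f16` when flash attention is on and the requested type is one of the supported values:

```go
package main

import (
	"fmt"
	"os"
	"slices"
	"strings"
)

// effectiveKVCacheType is an illustrative sketch, not the server's code:
// quantized K/V cache types only take effect when flash attention is enabled,
// otherwise the cache stays at f16.
func effectiveKVCacheType() string {
	flashAttention := os.Getenv("OLLAMA_FLASH_ATTENTION") == "1" // simplified boolean parsing
	requested := strings.ToLower(os.Getenv("OLLAMA_KV_CACHE_TYPE"))

	supported := []string{"f16", "q8_0", "q4_0"}
	if !flashAttention || requested == "" || !slices.Contains(supported, requested) {
		return "f16"
	}
	return requested
}

func main() {
	fmt.Println("effective K/V cache type:", effectiveKVCacheType())
}
```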

View File

@@ -63,7 +63,7 @@ SYSTEM You are Mario from super mario bros, acting as an assistant.
To use this:
1. Save it as a file (e.g. `Modelfile`)
2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>'`
2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>`
3. `ollama run choose-a-model-name`
4. Start using the model!
@@ -156,7 +156,7 @@ PARAMETER <parameter> <parametervalue>
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
| num_predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | int | num_predict 42 |
| num_predict | Maximum number of tokens to predict when generating text. (Default: -1, infinite generation) | int | num_predict 42 |
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
| min_p | Alternative to the top_p, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. (Default: 0.0) | float | min_p 0.05 |

View File

@@ -199,6 +199,8 @@ curl http://localhost:11434/v1/embeddings \
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `stream_options`
- [x] `include_usage`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
@@ -227,6 +229,8 @@ curl http://localhost:11434/v1/embeddings \
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `stream_options`
- [x] `include_usage`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
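For illustration (not part of this diff), a streaming request against the OpenAI-compatible chat endpoint with `stream_options.include_usage` set looks roughly like the sketch below; with the option enabled, the last data chunk before `[DONE]` has empty `choices` and carries the token counts in `usage`. The model name is a placeholder and the server is assumed to be running on the default address.

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"log"
	"net/http"
	"strings"
)

func main() {
	body := []byte(`{
		"model": "llama3.2",
		"messages": [{"role": "user", "content": "Say hello"}],
		"stream": true,
		"stream_options": {"include_usage": true}
	}`)

	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Print each SSE data line; the final usage chunk arrives just before "data: [DONE]".
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "data: ") {
			fmt.Println(line)
		}
	}
	if err := scanner.Err(); err != nil {
		log.Fatal(err)
	}
}
```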

View File

@@ -153,6 +153,8 @@ var (
Debug = Bool("OLLAMA_DEBUG")
// FlashAttention enables the experimental flash attention feature.
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
// KvCacheType is the quantization type for the K/V cache.
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
// NoHistory disables readline history.
NoHistory = Bool("OLLAMA_NOHISTORY")
// NoPrune disables pruning of model blobs on startup.
@@ -234,6 +236,7 @@ func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"}, "OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},

View File

@@ -105,7 +105,7 @@ make apply-patches
**Pin to new base commit**
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env`
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`
#### Applying patches

View File

@@ -85,9 +85,12 @@ COMPILER inline get_compiler() {
import "C" import "C"
import ( import (
"bytes"
_ "embed" _ "embed"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"runtime" "runtime"
"runtime/cgo" "runtime/cgo"
"slices" "slices"
@@ -140,7 +143,7 @@ type ContextParams struct {
c C.struct_llama_context_params
}
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
params := C.llama_context_default_params()
params.n_ctx = C.uint(numCtx)
params.n_batch = C.uint(batchSize)
@@ -149,9 +152,28 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
params.n_threads_batch = params.n_threads
params.embeddings = C.bool(true)
params.flash_attn = C.bool(flashAttention)
params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
return ContextParams{c: params}
}
// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
if s == "" {
return C.GGML_TYPE_F16
}
switch s {
case "q8_0":
return C.GGML_TYPE_Q8_0
case "q4_0":
return C.GGML_TYPE_Q4_0
default:
return C.GGML_TYPE_F16
}
}
type Context struct {
c *C.struct_llama_context
numThreads int
@@ -680,3 +702,33 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}
type JsonSchema struct {
Defs map[string]any `json:"$defs,omitempty"`
Properties map[string]any `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
Title string `json:"title,omitempty"`
Type string `json:"type,omitempty"`
}
func (js JsonSchema) AsGrammar() string {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(js); err != nil {
return ""
}
cStr := C.CString(b.String())
defer C.free(unsafe.Pointer(cStr))
// Allocate buffer for grammar output with reasonable size
const maxLen = 32768 // 32KB
buf := make([]byte, maxLen)
// Call C function to convert schema to grammar
length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
if length == 0 {
slog.Warn("unable to convert schema to grammar")
}
return string(buf[:length])
}
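A rough usage sketch for the new `JsonSchema` helper (not part of this diff; building it requires the repository's cgo llama bindings):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/llama"
)

func main() {
	// JsonSchema and AsGrammar come from the diff above.
	js := llama.JsonSchema{
		Type: "object",
		Properties: map[string]any{
			"name": map[string]any{"type": "string"},
			"age":  map[string]any{"type": "integer"},
		},
		Required: []string{"name"},
	}

	// AsGrammar converts the schema to a GBNF grammar via schema_to_grammar;
	// it returns an empty string if the schema cannot be converted.
	fmt.Println(js.AsGrammar())
}
```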

View File

@@ -1 +1,70 @@
package llama
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestJsonSchema(t *testing.T) {
testCases := []struct {
name string
schema JsonSchema
expected string
}{
{
name: "empty schema",
schema: JsonSchema{
Type: "object",
},
expected: `array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null`,
},
{
name: "invalid schema with circular reference",
schema: JsonSchema{
Type: "object",
Properties: map[string]any{
"self": map[string]any{
"$ref": "#", // Self reference
},
},
},
expected: "", // Should return empty string for invalid schema
},
{
name: "schema with invalid type",
schema: JsonSchema{
Type: "invalid_type", // Invalid type
Properties: map[string]any{
"foo": map[string]any{
"type": "string",
},
},
},
expected: "", // Should return empty string for invalid schema
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := tc.schema.AsGrammar()
if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
if diff := cmp.Diff(tc.expected, result); diff != "" {
t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
}
}
})
}
}

View File

@@ -850,6 +850,7 @@ func (s *Server) loadModel(
lpath multiLPath,
ppath string,
kvSize int,
kvCacheType string,
flashAttention bool,
threads int,
multiUserCache bool,
@@ -862,7 +863,7 @@ func (s *Server) loadModel(
panic(err)
}
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
if err != nil {
panic(err)
@@ -903,6 +904,7 @@ func main() {
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := flag.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := flag.Int("port", 8080, "Port to expose the server on")
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
@@ -970,7 +972,7 @@ func main() {
}
server.ready.Add(1)
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *kvCacheType, *flashAttention, *threads, *multiUserCache)
server.cond = sync.NewCond(&server.mu)

View File

@@ -1,11 +1,13 @@
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
#include "sampling.h"
#include "sampling_ext.h"
#include "json-schema-to-grammar.h"
struct gpt_sampler *gpt_sampler_cinit(
const struct llama_model *model, struct gpt_sampler_cparams *params)
{
try {
try
{
gpt_sampler_params sparams;
sparams.top_k = params->top_k;
sparams.top_p = params->top_p;
@@ -24,7 +26,9 @@ struct gpt_sampler *gpt_sampler_cinit(
sparams.seed = params->seed;
sparams.grammar = params->grammar;
return gpt_sampler_init(model, sparams);
} catch (const std::exception & err) {
}
catch (const std::exception &err)
{
return nullptr;
}
}
@@ -54,3 +58,24 @@ void gpt_sampler_caccept(
{
gpt_sampler_accept(sampler, id, apply_grammar);
}
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
{
try
{
nlohmann::json schema = nlohmann::json::parse(json_schema);
std::string grammar_str = json_schema_to_grammar(schema);
size_t len = grammar_str.length();
if (len >= max_len)
{
len = max_len - 1;
}
strncpy(grammar, grammar_str.c_str(), len);
return len;
}
catch (const std::exception &e)
{
strncpy(grammar, "", max_len - 1);
return 0;
}
}

View File

@@ -47,6 +47,8 @@ extern "C"
llama_token id,
bool apply_grammar);
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
#ifdef __cplusplus
}
#endif

View File

@@ -360,7 +360,7 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
}, offset, nil
}
func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffload uint64) {
func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
embedding := llm.KV().EmbeddingLength()
heads := llm.KV().HeadCount()
headsKV := llm.KV().HeadCountKV()
@@ -372,7 +372,8 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
layers := llm.Tensors().Layers()
kv = 2 * context * llm.KV().BlockCount() * (embeddingHeadsK + embeddingHeadsV) * headsKV
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
switch llm.KV().Architecture() {
case "llama":
@@ -527,3 +528,34 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
return
}
// SupportsKVCacheType checks if the requested cache type is supported
func (ggml GGML) SupportsKVCacheType(cacheType string) bool {
validKVCacheTypes := []string{"f16", "q8_0", "q4_0"}
return slices.Contains(validKVCacheTypes, cacheType)
}
// SupportsFlashAttention checks if the model supports flash attention
func (ggml GGML) SupportsFlashAttention() bool {
_, isEmbedding := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]
if isEmbedding {
return false
}
// Check head counts match and are non-zero
headCountK := ggml.KV().EmbeddingHeadCountK()
headCountV := ggml.KV().EmbeddingHeadCountV()
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
func kvCacheBytesPerElement(cacheType string) float64 {
switch cacheType {
case "q8_0":
return 1 // 1/2 of fp16
case "q4_0":
return 0.5 // 1/4 of fp16
default:
return 2 // f16 (default)
}
}
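To make the memory arithmetic concrete, here is a small standalone sketch using the per-element sizes from `kvCacheBytesPerElement` above and the K/V size formula from `GraphSize`; the model dimensions are hypothetical, roughly an 8B llama-style model:

```go
package main

import "fmt"

// bytesPerElement mirrors kvCacheBytesPerElement from the diff above.
func bytesPerElement(cacheType string) float64 {
	switch cacheType {
	case "q8_0":
		return 1
	case "q4_0":
		return 0.5
	default:
		return 2 // f16
	}
}

func main() {
	// Hypothetical model shape: 32 transformer blocks, 8 KV heads,
	// 128-dim K and V heads, and an 8192-token context.
	const (
		context  = 8192
		blocks   = 32
		headsKV  = 8
		headDimK = 128
		headDimV = 128
	)

	elements := float64(context * blocks * (headDimK + headDimV) * headsKV)
	for _, t := range []string{"f16", "q8_0", "q4_0"} {
		bytes := elements * bytesPerElement(t)
		fmt.Printf("%-5s ~%.2f GiB\n", t, bytes/(1<<30))
	}
}
```

With these assumed dimensions the cache is about 1 GiB at `f16`, halving to about 0.5 GiB with `q8_0` and 0.25 GiB with `q4_0`, matching the ratios described in the FAQ changes earlier in this diff.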

View File

@@ -123,7 +123,23 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
slog.Warn("model missing blk.0 layer size")
}
kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
fa := envconfig.FlashAttention() &&
discover.GetGPUInfo().FlashAttentionSupported() &&
ggml.SupportsFlashAttention()
var kvct string
if fa {
requested := strings.ToLower(envconfig.KvCacheType())
if requested != "" && ggml.SupportsKVCacheType(requested) {
kvct = requested
}
}
kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
// KV is proportional to the number of layers
layerSize += kv / ggml.KV().BlockCount()
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
@@ -131,9 +147,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
graphFullOffload = graphPartialOffload
}
// KV is proportional to the number of layers
layerSize += kv / ggml.KV().BlockCount()
// on metal there's no partial offload overhead
if gpus[0].Library == "metal" {
graphPartialOffload = graphFullOffload

View File

@@ -15,6 +15,7 @@ import (
func TestEstimateGPULayers(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "1")
t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
modelName := "dummy"
f, err := os.CreateTemp(t.TempDir(), modelName)

View File

@@ -214,15 +214,36 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
params = append(params, "--threads", strconv.Itoa(defaultThreads))
}
flashAttnEnabled := envconfig.FlashAttention()
fa := envconfig.FlashAttention()
if fa && !gpus.FlashAttentionSupported() {
slog.Warn("flash attention enabled but not supported by gpu")
fa = false
}
for _, g := range gpus {
// only cuda (compute capability 7+) and metal support flash attention
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
if fa && !ggml.SupportsFlashAttention() {
slog.Warn("flash attention enabled but not supported by model")
fa = false
}
kvct := strings.ToLower(envconfig.KvCacheType())
if fa {
slog.Info("enabling flash attention")
params = append(params, "--flash-attn")
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if kvct != "" && ggml.SupportsKVCacheType(kvct) {
params = append(params, "--kv-cache-type", kvct)
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
}
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
} }
// mmap has issues with partial offloading on metal
for _, g := range gpus {
if g.Library == "metal" &&
uint64(opts.NumGPU) > 0 &&
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
@@ -231,10 +252,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
}
// Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache
@@ -617,27 +634,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
const jsonGrammar = `
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
@@ -667,7 +679,7 @@ type completion struct {
type CompletionRequest struct {
Prompt string
Format string
Format json.RawMessage
Images []ImageData
Options *api.Options
}
@@ -732,10 +744,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("unexpected server status: %s", status.ToString())
}
if req.Format == "json" {
// TODO (parthsareen): Move conversion to grammar with sampling logic
// API should do error handling for invalid formats
if req.Format != nil {
if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
request["grammar"] = jsonGrammar
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
}
} else if schema, err := func() (llama.JsonSchema, error) {
var schema llama.JsonSchema
err := json.Unmarshal(req.Format, &schema)
return schema, err
}(); err == nil {
request["grammar"] = schema.AsGrammar()
} else {
slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
}
}

View File

@@ -63,6 +63,11 @@ type Usage struct {
type ResponseFormat struct {
Type string `json:"type"`
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}
type JsonSchema struct {
Schema map[string]any `json:"schema"`
}
type EmbedRequest struct {
@@ -70,10 +75,15 @@ type EmbedRequest struct {
Model string `json:"model"`
}
type StreamOptions struct {
IncludeUsage bool `json:"include_usage"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
MaxTokens *int `json:"max_tokens"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
@@ -102,6 +112,7 @@ type ChatCompletionChunk struct {
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []ChunkChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
@@ -114,6 +125,7 @@ type CompletionRequest struct {
Seed *int `json:"seed"`
Stop any `json:"stop"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
@@ -136,10 +148,12 @@ type CompletionChunk struct {
Choices []CompleteChunkChoice `json:"choices"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Usage *Usage `json:"usage,omitempty"`
}
type ToolCall struct {
ID string `json:"id"`
Index int `json:"index"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
@@ -191,6 +205,14 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}}
}
func toUsage(r api.ChatResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
}
}
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
@@ -206,6 +228,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name
toolCalls[i].Index = tc.Function.Index
args, err := json.Marshal(tc.Function.Arguments)
if err != nil {
@@ -239,11 +262,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
return nil
}(r.DoneReason),
}},
Usage: Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
},
Usage: toUsage(r),
}
}
@@ -268,6 +287,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
}
}
func toUsageGenerate(r api.GenerateResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
}
}
func toCompletion(id string, r api.GenerateResponse) Completion {
return Completion{
Id: id,
@@ -285,11 +312,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
return nil
}(r.DoneReason),
}},
Usage: Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
},
Usage: toUsageGenerate(r),
}
}
@@ -480,9 +503,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
options["top_p"] = 1.0
}
var format string
if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
format = "json"
var format json.RawMessage
if r.ResponseFormat != nil {
switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
// Support the old "json_object" type for OpenAI compatibility
case "json_object":
format = json.RawMessage(`"json"`)
case "json_schema":
if r.ResponseFormat.JsonSchema != nil {
schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
if err != nil {
return nil, fmt.Errorf("failed to marshal json schema: %w", err)
}
format = schema
}
}
}
return &api.ChatRequest{
@@ -552,12 +587,14 @@ type BaseWriter struct {
type ChatWriter struct {
stream bool
streamOptions *StreamOptions
id string
BaseWriter
}
type CompleteWriter struct {
stream bool
streamOptions *StreamOptions
id string
BaseWriter
}
@@ -601,7 +638,11 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
// chat chunk
if w.stream {
d, err := json.Marshal(toChunk(w.id, chatResponse))
c := toChunk(w.id, chatResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &Usage{}
}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
@@ -613,6 +654,17 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
}
if chatResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsage(chatResponse)
d, err := json.Marshal(ChatCompletionChunk{Choices: []ChunkChoice{}, Usage: &u})
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
@@ -650,7 +702,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
// completion chunk
if w.stream {
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
c := toCompleteChunk(w.id, generateResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &Usage{}
}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
@@ -662,6 +718,17 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
}
if generateResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsageGenerate(generateResponse)
d, err := json.Marshal(CompletionChunk{Choices: []CompleteChunkChoice{}, Usage: &u})
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
@@ -827,6 +894,7 @@ func CompletionsMiddleware() gin.HandlerFunc {
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
@@ -909,6 +977,7 @@ func ChatMiddleware() gin.HandlerFunc {
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w

View File

@@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -107,7 +108,46 @@ func TestChatMiddleware(t *testing.T) {
"presence_penalty": 5.0, "presence_penalty": 5.0,
"top_p": 6.0, "top_p": 6.0,
}, },
Format: "json", Format: json.RawMessage(`"json"`),
Stream: &True,
},
},
{
name: "chat handler with streaming usage",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": true,
"stream_options": {"include_usage": true},
"max_tokens": 999,
"seed": 123,
"stop": ["\n", "stop"],
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
"response_format": {"type": "json_object"}
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
"seed": 123.0,
"stop": []any{"\n", "stop"},
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
},
Format: json.RawMessage(`"json"`),
Stream: &True,
},
},
@@ -195,7 +235,86 @@ func TestChatMiddleware(t *testing.T) {
Stream: &False,
},
},
{
name: "chat handler with streaming tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris?"}
],
"stream": true,
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {
"type": "string",
"description": "The city and state"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
}
}
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris?",
},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &True,
},
},
{
name: "chat handler error forwarding",
body: `{
@@ -237,13 +356,13 @@ func TestChatMiddleware(t *testing.T) {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
return
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match: %+v", diff)
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
}
})
}
@@ -283,6 +402,55 @@ func TestCompletionsMiddleware(t *testing.T) {
Stream: &False,
},
},
{
name: "completions handler stream",
body: `{
"model": "test-model",
"prompt": "Hello",
"stream": true,
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &True,
},
},
{
name: "completions handler stream with usage",
body: `{
"model": "test-model",
"prompt": "Hello",
"stream": true,
"stream_options": {"include_usage": true},
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &True,
},
},
{
name: "completions handler error forwarding",
body: `{

View File

@@ -148,10 +148,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
if req.Format != "" && req.Format != "json" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
return
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
@@ -251,6 +248,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var b bytes.Buffer
if req.Context != nil {
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1469,7 +1467,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
go func() {
defer close(ch)
var sb strings.Builder
var hasToolCalls bool
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
@@ -1509,16 +1507,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
res.Message.Content = ""
sb.Reset()
hasToolCalls = true
ch <- res
return
}
if r.Done {
// Send any remaining content if no tool calls were detected
if !hasToolCalls {
if toolCallIndex == 0 {
res.Message.Content = sb.String()
}
ch <- res