Compare commits
21 Commits
v0.4.6
...
parth/open
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2536ffe0ab | ||
|
|
97abd7bfea | ||
|
|
c6509bf76e | ||
|
|
aed1419c64 | ||
|
|
c6c526275d | ||
|
|
630e7dc6ff | ||
|
|
eb8366d658 | ||
|
|
4456012956 | ||
|
|
539be43640 | ||
|
|
1bdab9fdb1 | ||
|
|
2b82c5a8a1 | ||
|
|
55c3efa900 | ||
|
|
1aedffad93 | ||
|
|
ff6c2d6dc8 | ||
|
|
d543b282a7 | ||
|
|
5f8051180e | ||
|
|
39e29ae5dd | ||
|
|
30a9f063c9 | ||
|
|
7355ab3703 | ||
|
|
7ed81437fe | ||
|
|
220108d3f4 |
3
.github/workflows/test.yaml
vendored
3
.github/workflows/test.yaml
vendored
@@ -310,8 +310,7 @@ jobs:
|
|||||||
arm64) echo ARCH=arm64 ;;
|
arm64) echo ARCH=arm64 ;;
|
||||||
esac >>$GITHUB_ENV
|
esac >>$GITHUB_ENV
|
||||||
shell: bash
|
shell: bash
|
||||||
- run: go build
|
- run: go test ./...
|
||||||
- run: go test -v ./...
|
|
||||||
|
|
||||||
patches:
|
patches:
|
||||||
needs: [changes]
|
needs: [changes]
|
||||||
|
|||||||
@@ -346,6 +346,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
|
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
|
||||||
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
|
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
|
||||||
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
||||||
|
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
|
||||||
|
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
|
||||||
|
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
|
||||||
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
|
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
|
||||||
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
|
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
|
||||||
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
|
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
|
||||||
@@ -356,6 +359,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama)
|
- [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama)
|
||||||
- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux)
|
- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux)
|
||||||
- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
|
- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
|
||||||
|
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
|
||||||
|
|
||||||
### Cloud
|
### Cloud
|
||||||
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ type GenerateRequest struct {
|
|||||||
Raw bool `json:"raw,omitempty"`
|
Raw bool `json:"raw,omitempty"`
|
||||||
|
|
||||||
// Format specifies the format to return a response in.
|
// Format specifies the format to return a response in.
|
||||||
Format string `json:"format"`
|
Format json.RawMessage `json:"format,omitempty"`
|
||||||
|
|
||||||
// KeepAlive controls how long the model will stay loaded in memory following
|
// KeepAlive controls how long the model will stay loaded in memory following
|
||||||
// this request.
|
// this request.
|
||||||
@@ -94,7 +94,7 @@ type ChatRequest struct {
|
|||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
|
|
||||||
// Format is the format to return the response in (e.g. "json").
|
// Format is the format to return the response in (e.g. "json").
|
||||||
Format string `json:"format"`
|
Format json.RawMessage `json:"format,omitempty"`
|
||||||
|
|
||||||
// KeepAlive controls how long the model will stay loaded into memory
|
// KeepAlive controls how long the model will stay loaded into memory
|
||||||
// following the request.
|
// following the request.
|
||||||
@@ -146,6 +146,7 @@ type ToolCall struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ToolCallFunction struct {
|
type ToolCallFunction struct {
|
||||||
|
Index int `json:"index,omitempty"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -1038,7 +1039,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
|||||||
req := &api.ChatRequest{
|
req := &api.ChatRequest{
|
||||||
Model: opts.Model,
|
Model: opts.Model,
|
||||||
Messages: opts.Messages,
|
Messages: opts.Messages,
|
||||||
Format: opts.Format,
|
Format: json.RawMessage(opts.Format),
|
||||||
Options: opts.Options,
|
Options: opts.Options,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1125,7 +1126,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
|||||||
Prompt: opts.Prompt,
|
Prompt: opts.Prompt,
|
||||||
Context: generateContext,
|
Context: generateContext,
|
||||||
Images: opts.Images,
|
Images: opts.Images,
|
||||||
Format: opts.Format,
|
Format: json.RawMessage(opts.Format),
|
||||||
System: opts.System,
|
System: opts.System,
|
||||||
Options: opts.Options,
|
Options: opts.Options,
|
||||||
KeepAlive: opts.KeepAlive,
|
KeepAlive: opts.KeepAlive,
|
||||||
@@ -1445,6 +1446,7 @@ func NewCLI() *cobra.Command {
|
|||||||
envVars["OLLAMA_SCHED_SPREAD"],
|
envVars["OLLAMA_SCHED_SPREAD"],
|
||||||
envVars["OLLAMA_TMPDIR"],
|
envVars["OLLAMA_TMPDIR"],
|
||||||
envVars["OLLAMA_FLASH_ATTENTION"],
|
envVars["OLLAMA_FLASH_ATTENTION"],
|
||||||
|
envVars["OLLAMA_KV_CACHE_TYPE"],
|
||||||
envVars["OLLAMA_LLM_LIBRARY"],
|
envVars["OLLAMA_LLM_LIBRARY"],
|
||||||
envVars["OLLAMA_GPU_OVERHEAD"],
|
envVars["OLLAMA_GPU_OVERHEAD"],
|
||||||
envVars["OLLAMA_LOAD_TIMEOUT"],
|
envVars["OLLAMA_LOAD_TIMEOUT"],
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -180,18 +179,14 @@ Weigh anchor!
|
|||||||
|
|
||||||
t.Run("license", func(t *testing.T) {
|
t.Run("license", func(t *testing.T) {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
license, err := os.ReadFile(filepath.Join("..", "LICENSE"))
|
license := "MIT License\nCopyright (c) Ollama\n"
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := showInfo(&api.ShowResponse{
|
if err := showInfo(&api.ShowResponse{
|
||||||
Details: api.ModelDetails{
|
Details: api.ModelDetails{
|
||||||
Family: "test",
|
Family: "test",
|
||||||
ParameterSize: "7B",
|
ParameterSize: "7B",
|
||||||
QuantizationLevel: "FP16",
|
QuantizationLevel: "FP16",
|
||||||
},
|
},
|
||||||
License: string(license),
|
License: license,
|
||||||
}, &b); err != nil {
|
}, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
)
|
)
|
||||||
@@ -60,7 +61,25 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
|||||||
addedTokens[t.Content] = t
|
addedTokens[t.Content] = t
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Merges = tt.Model.Merges
|
if len(tt.Model.Merges) == 0 {
|
||||||
|
// noop; merges is empty
|
||||||
|
} else if err := json.Unmarshal(tt.Model.Merges, &t.Merges); err == nil {
|
||||||
|
// noop; merges is []string
|
||||||
|
} else if merges, err := func() ([][]string, error) {
|
||||||
|
var merges [][]string
|
||||||
|
if err := json.Unmarshal(tt.Model.Merges, &merges); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return merges, nil
|
||||||
|
}(); err == nil {
|
||||||
|
t.Merges = make([]string, len(merges))
|
||||||
|
for i := range merges {
|
||||||
|
t.Merges[i] = strings.Join(merges[i], " ")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("could not parse tokenizer merges. expected []string or [][]string: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
sha256sum := sha256.New()
|
sha256sum := sha256.New()
|
||||||
for _, pt := range tt.PreTokenizer.PreTokenizers {
|
for _, pt := range tt.PreTokenizer.PreTokenizers {
|
||||||
@@ -158,7 +177,7 @@ type tokenizer struct {
|
|||||||
Model struct {
|
Model struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Vocab map[string]int `json:"vocab"`
|
Vocab map[string]int `json:"vocab"`
|
||||||
Merges []string `json:"merges"`
|
Merges json.RawMessage `json:"merges"`
|
||||||
} `json:"model"`
|
} `json:"model"`
|
||||||
|
|
||||||
PreTokenizer struct {
|
PreTokenizer struct {
|
||||||
|
|||||||
@@ -191,6 +191,62 @@ func TestParseTokenizer(t *testing.T) {
|
|||||||
Pre: "default",
|
Pre: "default",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "list string merges",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{
|
||||||
|
"model": {
|
||||||
|
"merges": [
|
||||||
|
"a b",
|
||||||
|
"c d",
|
||||||
|
"e f"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{
|
||||||
|
Model: "gpt2",
|
||||||
|
},
|
||||||
|
Merges: []string{
|
||||||
|
"a b",
|
||||||
|
"c d",
|
||||||
|
"e f",
|
||||||
|
},
|
||||||
|
Pre: "default",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "list list string merges",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{
|
||||||
|
"model": {
|
||||||
|
"merges": [
|
||||||
|
[
|
||||||
|
"a", "b"
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"c", "d"
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"e", "f"
|
||||||
|
]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{
|
||||||
|
Model: "gpt2",
|
||||||
|
},
|
||||||
|
Merges: []string{
|
||||||
|
"a b",
|
||||||
|
"c d",
|
||||||
|
"e f",
|
||||||
|
},
|
||||||
|
Pre: "default",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
|
|||||||
@@ -183,3 +183,17 @@ func (si SystemInfo) GetOptimalThreadCount() int {
|
|||||||
|
|
||||||
return coreCount
|
return coreCount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For each GPU, check if it does NOT support flash attention
|
||||||
|
func (l GpuInfoList) FlashAttentionSupported() bool {
|
||||||
|
for _, gpu := range l {
|
||||||
|
supportsFA := gpu.Library == "metal" ||
|
||||||
|
(gpu.Library == "cuda" && gpu.DriverMajor >= 7) ||
|
||||||
|
gpu.Library == "rocm"
|
||||||
|
|
||||||
|
if !supportsFA {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -49,10 +49,10 @@ Advanced parameters (optional):
|
|||||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||||
- `system`: system message to (overrides what is defined in the `Modelfile`)
|
- `system`: system message to (overrides what is defined in the `Modelfile`)
|
||||||
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
|
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
|
||||||
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
|
|
||||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||||
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API
|
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API
|
||||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||||
|
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
|
||||||
|
|
||||||
#### JSON mode
|
#### JSON mode
|
||||||
|
|
||||||
|
|||||||
28
docs/faq.md
28
docs/faq.md
@@ -151,7 +151,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
|
|||||||
|
|
||||||
Ollama runs an HTTP server and can be exposed using a proxy server such as Nginx. To do so, configure the proxy to forward requests and optionally set required headers (if not exposing Ollama on the network). For example, with Nginx:
|
Ollama runs an HTTP server and can be exposed using a proxy server such as Nginx. To do so, configure the proxy to forward requests and optionally set required headers (if not exposing Ollama on the network). For example, with Nginx:
|
||||||
|
|
||||||
```
|
```nginx
|
||||||
server {
|
server {
|
||||||
listen 80;
|
listen 80;
|
||||||
server_name example.com; # Replace with your domain or IP
|
server_name example.com; # Replace with your domain or IP
|
||||||
@@ -285,4 +285,28 @@ Note: Windows with Radeon GPUs currently default to 1 model maximum due to limit
|
|||||||
|
|
||||||
## How does Ollama load models on multiple GPUs?
|
## How does Ollama load models on multiple GPUs?
|
||||||
|
|
||||||
Installing multiple GPUs of the same brand can be a great way to increase your available VRAM to load larger models. When you load a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transfering across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs.
|
When loading a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transferring across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs.
|
||||||
|
|
||||||
|
## How can I enable Flash Attention?
|
||||||
|
|
||||||
|
Flash Attention is a feature of most modern models that can significantly reduce memory usage as the context size grows. To enable Flash Attention, set the `OLLAMA_FLASH_ATTENTION` environment variable to `1` when starting the Ollama server.
|
||||||
|
|
||||||
|
## How can I set the quantization type for the K/V cache?
|
||||||
|
|
||||||
|
The K/V context cache can be quantized to significantly reduce memory usage when Flash Attention is enabled.
|
||||||
|
|
||||||
|
To use quantized K/V cache with Ollama you can set the following environment variable:
|
||||||
|
|
||||||
|
- `OLLAMA_KV_CACHE_TYPE` - The quantization type for the K/V cache. Default is `f16`.
|
||||||
|
|
||||||
|
> Note: Currently this is a global option - meaning all models will run with the specified quantization type.
|
||||||
|
|
||||||
|
The currently available K/V cache quantization types are:
|
||||||
|
|
||||||
|
- `f16` - high precision and memory usage (default).
|
||||||
|
- `q8_0` - 8-bit quantization, uses approximately 1/2 the memory of `f16` with a very small loss in precision, this usually has no noticeable impact on the model's quality (recommended if not using f16).
|
||||||
|
- `q4_0` - 4-bit quantization, uses approximately 1/4 the memory of `f16` with a small-medium loss in precision that may be more noticeable at higher context sizes.
|
||||||
|
|
||||||
|
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
|
||||||
|
|
||||||
|
You may need to experiment with different quantization types to find the best balance between memory usage and quality.
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ SYSTEM You are Mario from super mario bros, acting as an assistant.
|
|||||||
To use this:
|
To use this:
|
||||||
|
|
||||||
1. Save it as a file (e.g. `Modelfile`)
|
1. Save it as a file (e.g. `Modelfile`)
|
||||||
2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>'`
|
2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>`
|
||||||
3. `ollama run choose-a-model-name`
|
3. `ollama run choose-a-model-name`
|
||||||
4. Start using the model!
|
4. Start using the model!
|
||||||
|
|
||||||
@@ -156,7 +156,7 @@ PARAMETER <parameter> <parametervalue>
|
|||||||
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
|
||||||
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
|
||||||
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
|
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
|
||||||
| num_predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | int | num_predict 42 |
|
| num_predict | Maximum number of tokens to predict when generating text. (Default: -1, infinite generation) | int | num_predict 42 |
|
||||||
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
|
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
|
||||||
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
|
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
|
||||||
| min_p | Alternative to the top_p, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. (Default: 0.0) | float | min_p 0.05 |
|
| min_p | Alternative to the top_p, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. (Default: 0.0) | float | min_p 0.05 |
|
||||||
|
|||||||
@@ -199,6 +199,8 @@ curl http://localhost:11434/v1/embeddings \
|
|||||||
- [x] `seed`
|
- [x] `seed`
|
||||||
- [x] `stop`
|
- [x] `stop`
|
||||||
- [x] `stream`
|
- [x] `stream`
|
||||||
|
- [x] `stream_options`
|
||||||
|
- [x] `include_usage`
|
||||||
- [x] `temperature`
|
- [x] `temperature`
|
||||||
- [x] `top_p`
|
- [x] `top_p`
|
||||||
- [x] `max_tokens`
|
- [x] `max_tokens`
|
||||||
@@ -227,6 +229,8 @@ curl http://localhost:11434/v1/embeddings \
|
|||||||
- [x] `seed`
|
- [x] `seed`
|
||||||
- [x] `stop`
|
- [x] `stop`
|
||||||
- [x] `stream`
|
- [x] `stream`
|
||||||
|
- [x] `stream_options`
|
||||||
|
- [x] `include_usage`
|
||||||
- [x] `temperature`
|
- [x] `temperature`
|
||||||
- [x] `top_p`
|
- [x] `top_p`
|
||||||
- [x] `max_tokens`
|
- [x] `max_tokens`
|
||||||
|
|||||||
@@ -153,6 +153,8 @@ var (
|
|||||||
Debug = Bool("OLLAMA_DEBUG")
|
Debug = Bool("OLLAMA_DEBUG")
|
||||||
// FlashAttention enables the experimental flash attention feature.
|
// FlashAttention enables the experimental flash attention feature.
|
||||||
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
|
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
|
||||||
|
// KvCacheType is the quantization type for the K/V cache.
|
||||||
|
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||||
// NoHistory disables readline history.
|
// NoHistory disables readline history.
|
||||||
NoHistory = Bool("OLLAMA_NOHISTORY")
|
NoHistory = Bool("OLLAMA_NOHISTORY")
|
||||||
// NoPrune disables pruning of model blobs on startup.
|
// NoPrune disables pruning of model blobs on startup.
|
||||||
@@ -234,6 +236,7 @@ func AsMap() map[string]EnvVar {
|
|||||||
ret := map[string]EnvVar{
|
ret := map[string]EnvVar{
|
||||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
|
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
|
||||||
|
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ make apply-patches
|
|||||||
|
|
||||||
**Pin to new base commit**
|
**Pin to new base commit**
|
||||||
|
|
||||||
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env`
|
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`
|
||||||
|
|
||||||
#### Applying patches
|
#### Applying patches
|
||||||
|
|
||||||
|
|||||||
@@ -85,9 +85,12 @@ COMPILER inline get_compiler() {
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/cgo"
|
"runtime/cgo"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -140,7 +143,7 @@ type ContextParams struct {
|
|||||||
c C.struct_llama_context_params
|
c C.struct_llama_context_params
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
|
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
|
||||||
params := C.llama_context_default_params()
|
params := C.llama_context_default_params()
|
||||||
params.n_ctx = C.uint(numCtx)
|
params.n_ctx = C.uint(numCtx)
|
||||||
params.n_batch = C.uint(batchSize)
|
params.n_batch = C.uint(batchSize)
|
||||||
@@ -149,9 +152,28 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
|
|||||||
params.n_threads_batch = params.n_threads
|
params.n_threads_batch = params.n_threads
|
||||||
params.embeddings = C.bool(true)
|
params.embeddings = C.bool(true)
|
||||||
params.flash_attn = C.bool(flashAttention)
|
params.flash_attn = C.bool(flashAttention)
|
||||||
|
params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
|
||||||
|
params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
|
||||||
|
|
||||||
return ContextParams{c: params}
|
return ContextParams{c: params}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
|
||||||
|
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
|
||||||
|
if s == "" {
|
||||||
|
return C.GGML_TYPE_F16
|
||||||
|
}
|
||||||
|
|
||||||
|
switch s {
|
||||||
|
case "q8_0":
|
||||||
|
return C.GGML_TYPE_Q8_0
|
||||||
|
case "q4_0":
|
||||||
|
return C.GGML_TYPE_Q4_0
|
||||||
|
default:
|
||||||
|
return C.GGML_TYPE_F16
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Context struct {
|
type Context struct {
|
||||||
c *C.struct_llama_context
|
c *C.struct_llama_context
|
||||||
numThreads int
|
numThreads int
|
||||||
@@ -680,3 +702,33 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
|
|||||||
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
|
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
|
||||||
C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
|
C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type JsonSchema struct {
|
||||||
|
Defs map[string]any `json:"$defs,omitempty"`
|
||||||
|
Properties map[string]any `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (js JsonSchema) AsGrammar() string {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(js); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
cStr := C.CString(b.String())
|
||||||
|
defer C.free(unsafe.Pointer(cStr))
|
||||||
|
|
||||||
|
// Allocate buffer for grammar output with reasonable size
|
||||||
|
const maxLen = 32768 // 32KB
|
||||||
|
buf := make([]byte, maxLen)
|
||||||
|
|
||||||
|
// Call C function to convert schema to grammar
|
||||||
|
length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
|
||||||
|
if length == 0 {
|
||||||
|
slog.Warn("unable to convert schema to grammar")
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(buf[:length])
|
||||||
|
}
|
||||||
|
|||||||
@@ -1 +1,70 @@
|
|||||||
package llama
|
package llama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJsonSchema(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
schema JsonSchema
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty schema",
|
||||||
|
schema: JsonSchema{
|
||||||
|
Type: "object",
|
||||||
|
},
|
||||||
|
expected: `array ::= "[" space ( value ("," space value)* )? "]" space
|
||||||
|
boolean ::= ("true" | "false") space
|
||||||
|
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||||
|
decimal-part ::= [0-9]{1,16}
|
||||||
|
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||||
|
null ::= "null" space
|
||||||
|
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||||
|
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||||
|
root ::= object
|
||||||
|
space ::= | " " | "\n" [ \t]{0,20}
|
||||||
|
string ::= "\"" char* "\"" space
|
||||||
|
value ::= object | array | string | number | boolean | null`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid schema with circular reference",
|
||||||
|
schema: JsonSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]any{
|
||||||
|
"self": map[string]any{
|
||||||
|
"$ref": "#", // Self reference
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "", // Should return empty string for invalid schema
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "schema with invalid type",
|
||||||
|
schema: JsonSchema{
|
||||||
|
Type: "invalid_type", // Invalid type
|
||||||
|
Properties: map[string]any{
|
||||||
|
"foo": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "", // Should return empty string for invalid schema
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := tc.schema.AsGrammar()
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
|
||||||
|
if diff := cmp.Diff(tc.expected, result); diff != "" {
|
||||||
|
t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -850,6 +850,7 @@ func (s *Server) loadModel(
|
|||||||
lpath multiLPath,
|
lpath multiLPath,
|
||||||
ppath string,
|
ppath string,
|
||||||
kvSize int,
|
kvSize int,
|
||||||
|
kvCacheType string,
|
||||||
flashAttention bool,
|
flashAttention bool,
|
||||||
threads int,
|
threads int,
|
||||||
multiUserCache bool,
|
multiUserCache bool,
|
||||||
@@ -862,7 +863,7 @@ func (s *Server) loadModel(
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
|
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
|
||||||
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
|
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@@ -903,6 +904,7 @@ func main() {
|
|||||||
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
|
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
|
||||||
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
|
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
|
||||||
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
|
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||||
|
kvCacheType := flag.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||||
port := flag.Int("port", 8080, "Port to expose the server on")
|
port := flag.Int("port", 8080, "Port to expose the server on")
|
||||||
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
||||||
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
|
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
|
||||||
@@ -970,7 +972,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
server.ready.Add(1)
|
server.ready.Add(1)
|
||||||
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
|
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *kvCacheType, *flashAttention, *threads, *multiUserCache)
|
||||||
|
|
||||||
server.cond = sync.NewCond(&server.mu)
|
server.cond = sync.NewCond(&server.mu)
|
||||||
|
|
||||||
|
|||||||
29
llama/sampling_ext.cpp
vendored
29
llama/sampling_ext.cpp
vendored
@@ -1,11 +1,13 @@
|
|||||||
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include "sampling_ext.h"
|
#include "sampling_ext.h"
|
||||||
|
#include "json-schema-to-grammar.h"
|
||||||
|
|
||||||
struct gpt_sampler *gpt_sampler_cinit(
|
struct gpt_sampler *gpt_sampler_cinit(
|
||||||
const struct llama_model *model, struct gpt_sampler_cparams *params)
|
const struct llama_model *model, struct gpt_sampler_cparams *params)
|
||||||
{
|
{
|
||||||
try {
|
try
|
||||||
|
{
|
||||||
gpt_sampler_params sparams;
|
gpt_sampler_params sparams;
|
||||||
sparams.top_k = params->top_k;
|
sparams.top_k = params->top_k;
|
||||||
sparams.top_p = params->top_p;
|
sparams.top_p = params->top_p;
|
||||||
@@ -24,7 +26,9 @@ struct gpt_sampler *gpt_sampler_cinit(
|
|||||||
sparams.seed = params->seed;
|
sparams.seed = params->seed;
|
||||||
sparams.grammar = params->grammar;
|
sparams.grammar = params->grammar;
|
||||||
return gpt_sampler_init(model, sparams);
|
return gpt_sampler_init(model, sparams);
|
||||||
} catch (const std::exception & err) {
|
}
|
||||||
|
catch (const std::exception &err)
|
||||||
|
{
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -54,3 +58,24 @@ void gpt_sampler_caccept(
|
|||||||
{
|
{
|
||||||
gpt_sampler_accept(sampler, id, apply_grammar);
|
gpt_sampler_accept(sampler, id, apply_grammar);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
nlohmann::json schema = nlohmann::json::parse(json_schema);
|
||||||
|
std::string grammar_str = json_schema_to_grammar(schema);
|
||||||
|
size_t len = grammar_str.length();
|
||||||
|
if (len >= max_len)
|
||||||
|
{
|
||||||
|
len = max_len - 1;
|
||||||
|
}
|
||||||
|
strncpy(grammar, grammar_str.c_str(), len);
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
catch (const std::exception &e)
|
||||||
|
{
|
||||||
|
strncpy(grammar, "", max_len - 1);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
2
llama/sampling_ext.h
vendored
2
llama/sampling_ext.h
vendored
@@ -47,6 +47,8 @@ extern "C"
|
|||||||
llama_token id,
|
llama_token id,
|
||||||
bool apply_grammar);
|
bool apply_grammar);
|
||||||
|
|
||||||
|
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
36
llm/ggml.go
36
llm/ggml.go
@@ -360,7 +360,7 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
|||||||
}, offset, nil
|
}, offset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffload uint64) {
|
func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||||
embedding := llm.KV().EmbeddingLength()
|
embedding := llm.KV().EmbeddingLength()
|
||||||
heads := llm.KV().HeadCount()
|
heads := llm.KV().HeadCount()
|
||||||
headsKV := llm.KV().HeadCountKV()
|
headsKV := llm.KV().HeadCountKV()
|
||||||
@@ -372,7 +372,8 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
|
|||||||
|
|
||||||
layers := llm.Tensors().Layers()
|
layers := llm.Tensors().Layers()
|
||||||
|
|
||||||
kv = 2 * context * llm.KV().BlockCount() * (embeddingHeadsK + embeddingHeadsV) * headsKV
|
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||||
|
kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||||
|
|
||||||
switch llm.KV().Architecture() {
|
switch llm.KV().Architecture() {
|
||||||
case "llama":
|
case "llama":
|
||||||
@@ -527,3 +528,34 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SupportsKVCacheType checks if the requested cache type is supported
|
||||||
|
func (ggml GGML) SupportsKVCacheType(cacheType string) bool {
|
||||||
|
validKVCacheTypes := []string{"f16", "q8_0", "q4_0"}
|
||||||
|
return slices.Contains(validKVCacheTypes, cacheType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportsFlashAttention checks if the model supports flash attention
|
||||||
|
func (ggml GGML) SupportsFlashAttention() bool {
|
||||||
|
_, isEmbedding := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]
|
||||||
|
if isEmbedding {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check head counts match and are non-zero
|
||||||
|
headCountK := ggml.KV().EmbeddingHeadCountK()
|
||||||
|
headCountV := ggml.KV().EmbeddingHeadCountV()
|
||||||
|
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||||
|
}
|
||||||
|
|
||||||
|
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||||
|
func kvCacheBytesPerElement(cacheType string) float64 {
|
||||||
|
switch cacheType {
|
||||||
|
case "q8_0":
|
||||||
|
return 1 // 1/2 of fp16
|
||||||
|
case "q4_0":
|
||||||
|
return 0.5 // 1/4 of fp16
|
||||||
|
default:
|
||||||
|
return 2 // f16 (default)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -123,7 +123,23 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
|
|||||||
slog.Warn("model missing blk.0 layer size")
|
slog.Warn("model missing blk.0 layer size")
|
||||||
}
|
}
|
||||||
|
|
||||||
kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
fa := envconfig.FlashAttention() &&
|
||||||
|
discover.GetGPUInfo().FlashAttentionSupported() &&
|
||||||
|
ggml.SupportsFlashAttention()
|
||||||
|
|
||||||
|
var kvct string
|
||||||
|
if fa {
|
||||||
|
requested := strings.ToLower(envconfig.KvCacheType())
|
||||||
|
if requested != "" && ggml.SupportsKVCacheType(requested) {
|
||||||
|
kvct = requested
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
|
||||||
|
|
||||||
|
// KV is proportional to the number of layers
|
||||||
|
layerSize += kv / ggml.KV().BlockCount()
|
||||||
|
|
||||||
if graphPartialOffload == 0 {
|
if graphPartialOffload == 0 {
|
||||||
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
||||||
}
|
}
|
||||||
@@ -131,9 +147,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
|
|||||||
graphFullOffload = graphPartialOffload
|
graphFullOffload = graphPartialOffload
|
||||||
}
|
}
|
||||||
|
|
||||||
// KV is proportional to the number of layers
|
|
||||||
layerSize += kv / ggml.KV().BlockCount()
|
|
||||||
|
|
||||||
// on metal there's no partial offload overhead
|
// on metal there's no partial offload overhead
|
||||||
if gpus[0].Library == "metal" {
|
if gpus[0].Library == "metal" {
|
||||||
graphPartialOffload = graphFullOffload
|
graphPartialOffload = graphFullOffload
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
func TestEstimateGPULayers(t *testing.T) {
|
func TestEstimateGPULayers(t *testing.T) {
|
||||||
t.Setenv("OLLAMA_DEBUG", "1")
|
t.Setenv("OLLAMA_DEBUG", "1")
|
||||||
|
t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
|
||||||
|
|
||||||
modelName := "dummy"
|
modelName := "dummy"
|
||||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||||
|
|||||||
@@ -214,15 +214,36 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
|
|||||||
params = append(params, "--threads", strconv.Itoa(defaultThreads))
|
params = append(params, "--threads", strconv.Itoa(defaultThreads))
|
||||||
}
|
}
|
||||||
|
|
||||||
flashAttnEnabled := envconfig.FlashAttention()
|
fa := envconfig.FlashAttention()
|
||||||
|
if fa && !gpus.FlashAttentionSupported() {
|
||||||
|
slog.Warn("flash attention enabled but not supported by gpu")
|
||||||
|
fa = false
|
||||||
|
}
|
||||||
|
|
||||||
for _, g := range gpus {
|
if fa && !ggml.SupportsFlashAttention() {
|
||||||
// only cuda (compute capability 7+) and metal support flash attention
|
slog.Warn("flash attention enabled but not supported by model")
|
||||||
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
|
fa = false
|
||||||
flashAttnEnabled = false
|
}
|
||||||
|
|
||||||
|
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||||
|
|
||||||
|
if fa {
|
||||||
|
slog.Info("enabling flash attention")
|
||||||
|
params = append(params, "--flash-attn")
|
||||||
|
|
||||||
|
// Flash Attention also supports kv cache quantization
|
||||||
|
// Enable if the requested and kv cache type is supported by the model
|
||||||
|
if kvct != "" && ggml.SupportsKVCacheType(kvct) {
|
||||||
|
params = append(params, "--kv-cache-type", kvct)
|
||||||
|
} else {
|
||||||
|
slog.Warn("kv cache type not supported by model", "type", kvct)
|
||||||
|
}
|
||||||
|
} else if kvct != "" && kvct != "f16" {
|
||||||
|
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
|
||||||
}
|
}
|
||||||
|
|
||||||
// mmap has issues with partial offloading on metal
|
// mmap has issues with partial offloading on metal
|
||||||
|
for _, g := range gpus {
|
||||||
if g.Library == "metal" &&
|
if g.Library == "metal" &&
|
||||||
uint64(opts.NumGPU) > 0 &&
|
uint64(opts.NumGPU) > 0 &&
|
||||||
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
||||||
@@ -231,10 +252,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if flashAttnEnabled {
|
|
||||||
params = append(params, "--flash-attn")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Windows CUDA should not use mmap for best performance
|
// Windows CUDA should not use mmap for best performance
|
||||||
// Linux with a model larger than free space, mmap leads to thrashing
|
// Linux with a model larger than free space, mmap leads to thrashing
|
||||||
// For CPU loads we want the memory to be allocated, not FS cache
|
// For CPU loads we want the memory to be allocated, not FS cache
|
||||||
@@ -617,27 +634,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||||||
const jsonGrammar = `
|
const jsonGrammar = `
|
||||||
root ::= object
|
root ::= object
|
||||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
object ::=
|
object ::=
|
||||||
"{" ws (
|
"{" ws (
|
||||||
string ":" ws value
|
string ":" ws value
|
||||||
("," ws string ":" ws value)*
|
("," ws string ":" ws value)*
|
||||||
)? "}" ws
|
)? "}" ws
|
||||||
|
|
||||||
array ::=
|
array ::=
|
||||||
"[" ws (
|
"[" ws (
|
||||||
value
|
value
|
||||||
("," ws value)*
|
("," ws value)*
|
||||||
)? "]" ws
|
)? "]" ws
|
||||||
|
|
||||||
string ::=
|
string ::=
|
||||||
"\"" (
|
"\"" (
|
||||||
[^"\\\x7F\x00-\x1F] |
|
[^"\\\x7F\x00-\x1F] |
|
||||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
)* "\"" ws
|
)* "\"" ws
|
||||||
|
|
||||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
ws ::= ([ \t\n] ws)?
|
ws ::= ([ \t\n] ws)?
|
||||||
`
|
`
|
||||||
@@ -667,7 +679,7 @@ type completion struct {
|
|||||||
|
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Prompt string
|
Prompt string
|
||||||
Format string
|
Format json.RawMessage
|
||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options *api.Options
|
Options *api.Options
|
||||||
}
|
}
|
||||||
@@ -732,10 +744,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Format == "json" {
|
// TODO (parthsareen): Move conversion to grammar with sampling logic
|
||||||
|
// API should do error handling for invalid formats
|
||||||
|
if req.Format != nil {
|
||||||
|
if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
|
||||||
request["grammar"] = jsonGrammar
|
request["grammar"] = jsonGrammar
|
||||||
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
||||||
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
||||||
|
}
|
||||||
|
} else if schema, err := func() (llama.JsonSchema, error) {
|
||||||
|
var schema llama.JsonSchema
|
||||||
|
err := json.Unmarshal(req.Format, &schema)
|
||||||
|
return schema, err
|
||||||
|
}(); err == nil {
|
||||||
|
request["grammar"] = schema.AsGrammar()
|
||||||
|
} else {
|
||||||
|
slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,11 @@ type Usage struct {
|
|||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type JsonSchema struct {
|
||||||
|
Schema map[string]any `json:"schema"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbedRequest struct {
|
type EmbedRequest struct {
|
||||||
@@ -70,10 +75,15 @@ type EmbedRequest struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StreamOptions struct {
|
||||||
|
IncludeUsage bool `json:"include_usage"`
|
||||||
|
}
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
type ChatCompletionRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
|
StreamOptions *StreamOptions `json:"stream_options"`
|
||||||
MaxTokens *int `json:"max_tokens"`
|
MaxTokens *int `json:"max_tokens"`
|
||||||
Seed *int `json:"seed"`
|
Seed *int `json:"seed"`
|
||||||
Stop any `json:"stop"`
|
Stop any `json:"stop"`
|
||||||
@@ -102,6 +112,7 @@ type ChatCompletionChunk struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
SystemFingerprint string `json:"system_fingerprint"`
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
Choices []ChunkChoice `json:"choices"`
|
Choices []ChunkChoice `json:"choices"`
|
||||||
|
Usage *Usage `json:"usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
||||||
@@ -114,6 +125,7 @@ type CompletionRequest struct {
|
|||||||
Seed *int `json:"seed"`
|
Seed *int `json:"seed"`
|
||||||
Stop any `json:"stop"`
|
Stop any `json:"stop"`
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
|
StreamOptions *StreamOptions `json:"stream_options"`
|
||||||
Temperature *float32 `json:"temperature"`
|
Temperature *float32 `json:"temperature"`
|
||||||
TopP float32 `json:"top_p"`
|
TopP float32 `json:"top_p"`
|
||||||
Suffix string `json:"suffix"`
|
Suffix string `json:"suffix"`
|
||||||
@@ -136,10 +148,12 @@ type CompletionChunk struct {
|
|||||||
Choices []CompleteChunkChoice `json:"choices"`
|
Choices []CompleteChunkChoice `json:"choices"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
SystemFingerprint string `json:"system_fingerprint"`
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
Usage *Usage `json:"usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
|
Index int `json:"index"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Function struct {
|
Function struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@@ -191,6 +205,14 @@ func NewError(code int, message string) ErrorResponse {
|
|||||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toUsage(r api.ChatResponse) Usage {
|
||||||
|
return Usage{
|
||||||
|
PromptTokens: r.PromptEvalCount,
|
||||||
|
CompletionTokens: r.EvalCount,
|
||||||
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func toolCallId() string {
|
func toolCallId() string {
|
||||||
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
b := make([]byte, 8)
|
b := make([]byte, 8)
|
||||||
@@ -206,6 +228,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
|
|||||||
toolCalls[i].ID = toolCallId()
|
toolCalls[i].ID = toolCallId()
|
||||||
toolCalls[i].Type = "function"
|
toolCalls[i].Type = "function"
|
||||||
toolCalls[i].Function.Name = tc.Function.Name
|
toolCalls[i].Function.Name = tc.Function.Name
|
||||||
|
toolCalls[i].Index = tc.Function.Index
|
||||||
|
|
||||||
args, err := json.Marshal(tc.Function.Arguments)
|
args, err := json.Marshal(tc.Function.Arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -239,11 +262,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|||||||
return nil
|
return nil
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: toUsage(r),
|
||||||
PromptTokens: r.PromptEvalCount,
|
|
||||||
CompletionTokens: r.EvalCount,
|
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,6 +287,14 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toUsageGenerate(r api.GenerateResponse) Usage {
|
||||||
|
return Usage{
|
||||||
|
PromptTokens: r.PromptEvalCount,
|
||||||
|
CompletionTokens: r.EvalCount,
|
||||||
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func toCompletion(id string, r api.GenerateResponse) Completion {
|
func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||||
return Completion{
|
return Completion{
|
||||||
Id: id,
|
Id: id,
|
||||||
@@ -285,11 +312,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
|||||||
return nil
|
return nil
|
||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: toUsageGenerate(r),
|
||||||
PromptTokens: r.PromptEvalCount,
|
|
||||||
CompletionTokens: r.EvalCount,
|
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -480,9 +503,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||||||
options["top_p"] = 1.0
|
options["top_p"] = 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
var format string
|
var format json.RawMessage
|
||||||
if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
|
if r.ResponseFormat != nil {
|
||||||
format = "json"
|
switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
|
||||||
|
// Support the old "json_object" type for OpenAI compatibility
|
||||||
|
case "json_object":
|
||||||
|
format = json.RawMessage(`"json"`)
|
||||||
|
case "json_schema":
|
||||||
|
if r.ResponseFormat.JsonSchema != nil {
|
||||||
|
schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal json schema: %w", err)
|
||||||
|
}
|
||||||
|
format = schema
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.ChatRequest{
|
return &api.ChatRequest{
|
||||||
@@ -552,12 +587,14 @@ type BaseWriter struct {
|
|||||||
|
|
||||||
type ChatWriter struct {
|
type ChatWriter struct {
|
||||||
stream bool
|
stream bool
|
||||||
|
streamOptions *StreamOptions
|
||||||
id string
|
id string
|
||||||
BaseWriter
|
BaseWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompleteWriter struct {
|
type CompleteWriter struct {
|
||||||
stream bool
|
stream bool
|
||||||
|
streamOptions *StreamOptions
|
||||||
id string
|
id string
|
||||||
BaseWriter
|
BaseWriter
|
||||||
}
|
}
|
||||||
@@ -601,7 +638,11 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|||||||
|
|
||||||
// chat chunk
|
// chat chunk
|
||||||
if w.stream {
|
if w.stream {
|
||||||
d, err := json.Marshal(toChunk(w.id, chatResponse))
|
c := toChunk(w.id, chatResponse)
|
||||||
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||||
|
c.Usage = &Usage{}
|
||||||
|
}
|
||||||
|
d, err := json.Marshal(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -613,6 +654,17 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if chatResponse.Done {
|
if chatResponse.Done {
|
||||||
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||||
|
u := toUsage(chatResponse)
|
||||||
|
d, err := json.Marshal(ChatCompletionChunk{Choices: []ChunkChoice{}, Usage: &u})
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -650,7 +702,11 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|||||||
|
|
||||||
// completion chunk
|
// completion chunk
|
||||||
if w.stream {
|
if w.stream {
|
||||||
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
|
c := toCompleteChunk(w.id, generateResponse)
|
||||||
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||||
|
c.Usage = &Usage{}
|
||||||
|
}
|
||||||
|
d, err := json.Marshal(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -662,6 +718,17 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if generateResponse.Done {
|
if generateResponse.Done {
|
||||||
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||||
|
u := toUsageGenerate(generateResponse)
|
||||||
|
d, err := json.Marshal(CompletionChunk{Choices: []CompleteChunkChoice{}, Usage: &u})
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -827,6 +894,7 @@ func CompletionsMiddleware() gin.HandlerFunc {
|
|||||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
stream: req.Stream,
|
stream: req.Stream,
|
||||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||||
|
streamOptions: req.StreamOptions,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Writer = w
|
c.Writer = w
|
||||||
@@ -909,6 +977,7 @@ func ChatMiddleware() gin.HandlerFunc {
|
|||||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
stream: req.Stream,
|
stream: req.Stream,
|
||||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||||
|
streamOptions: req.StreamOptions,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Writer = w
|
c.Writer = w
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
@@ -107,7 +108,46 @@ func TestChatMiddleware(t *testing.T) {
|
|||||||
"presence_penalty": 5.0,
|
"presence_penalty": 5.0,
|
||||||
"top_p": 6.0,
|
"top_p": 6.0,
|
||||||
},
|
},
|
||||||
Format: "json",
|
Format: json.RawMessage(`"json"`),
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with streaming usage",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
],
|
||||||
|
"stream": true,
|
||||||
|
"stream_options": {"include_usage": true},
|
||||||
|
"max_tokens": 999,
|
||||||
|
"seed": 123,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"temperature": 3.0,
|
||||||
|
"frequency_penalty": 4.0,
|
||||||
|
"presence_penalty": 5.0,
|
||||||
|
"top_p": 6.0,
|
||||||
|
"response_format": {"type": "json_object"}
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Hello",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||||
|
"seed": 123.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
"temperature": 3.0,
|
||||||
|
"frequency_penalty": 4.0,
|
||||||
|
"presence_penalty": 5.0,
|
||||||
|
"top_p": 6.0,
|
||||||
|
},
|
||||||
|
Format: json.RawMessage(`"json"`),
|
||||||
Stream: &True,
|
Stream: &True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -195,7 +235,86 @@ func TestChatMiddleware(t *testing.T) {
|
|||||||
Stream: &False,
|
Stream: &False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with streaming tools",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris?"}
|
||||||
|
],
|
||||||
|
"stream": true,
|
||||||
|
"tools": [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["location"],
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris?",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather",
|
||||||
|
Parameters: struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Properties map[string]struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
} `json:"properties"`
|
||||||
|
}{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"location"},
|
||||||
|
Properties: map[string]struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}{
|
||||||
|
"location": {
|
||||||
|
Type: "string",
|
||||||
|
Description: "The city and state",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
Type: "string",
|
||||||
|
Enum: []string{"celsius", "fahrenheit"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "chat handler error forwarding",
|
name: "chat handler error forwarding",
|
||||||
body: `{
|
body: `{
|
||||||
@@ -237,13 +356,13 @@ func TestChatMiddleware(t *testing.T) {
|
|||||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||||
t.Fatal("requests did not match")
|
t.Fatalf("requests did not match: %+v", diff)
|
||||||
}
|
}
|
||||||
|
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||||
if !reflect.DeepEqual(tc.err, errResp) {
|
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
|
||||||
t.Fatal("errors did not match")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -283,6 +402,55 @@ func TestCompletionsMiddleware(t *testing.T) {
|
|||||||
Stream: &False,
|
Stream: &False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "completions handler stream",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello",
|
||||||
|
"stream": true,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"suffix": "suffix"
|
||||||
|
}`,
|
||||||
|
req: api.GenerateRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Options: map[string]any{
|
||||||
|
"frequency_penalty": 0.0,
|
||||||
|
"presence_penalty": 0.0,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
},
|
||||||
|
Suffix: "suffix",
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "completions handler stream with usage",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello",
|
||||||
|
"stream": true,
|
||||||
|
"stream_options": {"include_usage": true},
|
||||||
|
"temperature": 0.8,
|
||||||
|
"stop": ["\n", "stop"],
|
||||||
|
"suffix": "suffix"
|
||||||
|
}`,
|
||||||
|
req: api.GenerateRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Options: map[string]any{
|
||||||
|
"frequency_penalty": 0.0,
|
||||||
|
"presence_penalty": 0.0,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
|
},
|
||||||
|
Suffix: "suffix",
|
||||||
|
Stream: &True,
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "completions handler error forwarding",
|
name: "completions handler error forwarding",
|
||||||
body: `{
|
body: `{
|
||||||
|
|||||||
@@ -148,10 +148,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Format != "" && req.Format != "json" {
|
if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
|
|
||||||
return
|
|
||||||
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -251,6 +248,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if req.Context != nil {
|
if req.Context != nil {
|
||||||
|
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
|
||||||
s, err := r.Detokenize(c.Request.Context(), req.Context)
|
s, err := r.Detokenize(c.Request.Context(), req.Context)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -1469,7 +1467,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
var hasToolCalls bool
|
var toolCallIndex int = 0
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
@@ -1509,16 +1507,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
sb.WriteString(r.Content)
|
sb.WriteString(r.Content)
|
||||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
res.Message.ToolCalls = toolCalls
|
res.Message.ToolCalls = toolCalls
|
||||||
|
for i := range toolCalls {
|
||||||
|
toolCalls[i].Function.Index = toolCallIndex
|
||||||
|
toolCallIndex++
|
||||||
|
}
|
||||||
res.Message.Content = ""
|
res.Message.Content = ""
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
hasToolCalls = true
|
|
||||||
ch <- res
|
ch <- res
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Done {
|
if r.Done {
|
||||||
// Send any remaining content if no tool calls were detected
|
// Send any remaining content if no tool calls were detected
|
||||||
if !hasToolCalls {
|
if toolCallIndex == 0 {
|
||||||
res.Message.Content = sb.String()
|
res.Message.Content = sb.String()
|
||||||
}
|
}
|
||||||
ch <- res
|
ch <- res
|
||||||
|
|||||||
Reference in New Issue
Block a user