Compare commits

...

16 Commits

Author SHA1 Message Date
Josh Yan
4da5d5beaa lint 2024-08-28 10:23:41 -07:00
Josh Yan
cc17b02b23 update 2024-08-28 09:58:23 -07:00
Josh Yan
73d69bc90b remove types 2024-08-27 16:45:07 -07:00
Josh Yan
9bc42f532b rmv api type 2024-08-27 16:45:07 -07:00
Josh Yan
07c0f66f5e rm print 2024-08-27 16:45:04 -07:00
Josh Yan
4a7bfca902 change progress msg 2024-08-27 16:44:38 -07:00
Josh Yan
04f2154505 fixed cgo 2024-08-27 16:44:38 -07:00
Josh Yan
de9b21b472 quantize progress 2024-08-27 16:44:32 -07:00
Daniel Hiltgen
93ea9240ae Move ollama executable out of bin dir (#6535) 2024-08-27 16:19:00 -07:00
Patrick Devine
d13c3daa0b add safetensors to the modelfile docs (#6532) 2024-08-27 14:46:47 -07:00
Patrick Devine
1713eddcd0 Fix import image width (#6528) 2024-08-27 14:19:47 -07:00
Daniel Hiltgen
4e1c4f6e0b Update manual instructions with discrete ROCm bundle (#6445) 2024-08-27 13:42:28 -07:00
Sean Khatiri
397cae7962 llm: fix typo in comment (#6530) 2024-08-27 13:28:29 -07:00
Patrick Devine
1c70a00f71 adjust image sizes 2024-08-27 11:15:25 -07:00
Patrick Devine
ac80010db8 update the import docs (#6104) 2024-08-26 19:57:26 -07:00
Jeffrey Morgan
47fa0839b9 server: clean up route names for consistency (#6524) 2024-08-26 19:36:11 -07:00
24 changed files with 1606 additions and 126 deletions

1
.gitattributes vendored
View File

@@ -1,3 +1,4 @@
llm/ext_server/* linguist-vendored llm/ext_server/* linguist-vendored
llm/*.h linguist-vendored
* text=auto * text=auto
*.go text eol=lf *.go text eol=lf

View File

@@ -87,7 +87,7 @@ DialogFontSize=12
[Files] [Files]
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
Source: "..\ollama.exe"; DestDir: "{app}\bin"; Flags: ignoreversion 64bit Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-{#ARCH}\lib\ollama\runners\*"; DestDir: "{app}\lib\ollama\runners"; Flags: ignoreversion 64bit recursesubdirs Source: "..\dist\windows-{#ARCH}\lib\ollama\runners\*"; DestDir: "{app}\lib\ollama\runners"; Flags: ignoreversion 64bit recursesubdirs
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
@@ -99,7 +99,7 @@ Name: "{userstartup}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; IconFilen
Name: "{userprograms}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; IconFilename: "{app}\app.ico" Name: "{userprograms}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; IconFilename: "{app}\app.ico"
[Run] [Run]
Filename: "{cmd}"; Parameters: "/C set PATH={app}\bin;%PATH% & ""{app}\{#MyAppExeName}"""; Flags: postinstall nowait runhidden Filename: "{cmd}"; Parameters: "/C set PATH={app};%PATH% & ""{app}\{#MyAppExeName}"""; Flags: postinstall nowait runhidden
[UninstallRun] [UninstallRun]
; Filename: "{cmd}"; Parameters: "/C ""taskkill /im ''{#MyAppExeName}'' /f /t"; Flags: runhidden ; Filename: "{cmd}"; Parameters: "/C ""taskkill /im ''{#MyAppExeName}'' /f /t"; Flags: runhidden
@@ -134,8 +134,8 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
[Registry] [Registry]
Root: HKCU; Subkey: "Environment"; \ Root: HKCU; Subkey: "Environment"; \
ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}\bin"; \ ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \
Check: NeedsAddPath('{app}\bin') Check: NeedsAddPath('{app}')
[Code] [Code]

View File

@@ -124,6 +124,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
bars := make(map[string]*progress.Bar) bars := make(map[string]*progress.Bar)
var quantizeSpin *progress.Spinner
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" { if resp.Digest != "" {
spinner.Stop() spinner.Stop()
@@ -136,6 +137,15 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
bar.Set(resp.Completed) bar.Set(resp.Completed)
} else if strings.Contains(resp.Status, "quantizing") {
spinner.Stop()
if quantizeSpin != nil {
quantizeSpin.SetMessage(resp.Status)
} else {
quantizeSpin = progress.NewSpinner(resp.Status)
p.Add("quantize", quantizeSpin)
}
} else if status != resp.Status { } else if status != resp.Status {
spinner.Stop() spinner.Stop()

BIN
docs/images/ollama-keys.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 150 KiB

BIN
docs/images/signup.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

View File

@@ -1,44 +1,129 @@
# Import # Importing a model
GGUF models and select Safetensors models can be imported directly into Ollama. ## Table of Contents
## Import GGUF * [Importing a Safetensors adapter](#Importing-a-fine-tuned-adapter-from-Safetensors-weights)
* [Importing a Safetensors model](#Importing-a-model-from-Safetensors-weights)
* [Importing a GGUF file](#Importing-a-GGUF-based-model-or-adapter)
* [Sharing models on ollama.com](#Sharing-your-model-on-ollamacom)
A binary GGUF file can be imported directly into Ollama through a Modelfile. ## Importing a fine tuned adapter from Safetensors weights
First, create a `Modelfile` with a `FROM` command pointing at the base model you used for fine tuning, and an `ADAPTER` command which points to the directory with your Safetensors adapter:
```dockerfile ```dockerfile
FROM /path/to/file.gguf FROM <base model name>
ADAPTER /path/to/safetensors/adapter/directory
``` ```
## Import Safetensors Make sure that you use the same base model in the `FROM` command as you used to create the adapter otherwise you will get erratic results. Most frameworks use different quantization methods, so it's best to use non-quantized (i.e. non-QLoRA) adapters. If your adapter is in the same directory as your `Modelfile`, use `ADAPTER .` to specify the adapter path.
If the model being imported is one of these architectures, it can be imported directly into Ollama through a Modelfile: Now run `ollama create` from the directory where the `Modelfile` was created:
- LlamaForCausalLM ```bash
- MistralForCausalLM ollama create my-model
- MixtralForCausalLM ```
- GemmaForCausalLM
- Phi3ForCausalLM Lastly, test the model:
```bash
ollama run my-model
```
Ollama supports importing adapters based on several different model architectures including:
* Llama (including Llama 2, Llama 3, and Llama 3.1);
* Mistral (including Mistral 1, Mistral 2, and Mixtral); and
* Gemma (including Gemma 1 and Gemma 2)
You can create the adapter using a fine tuning framework or tool which can output adapters in the Safetensors format, such as:
* Hugging Face [fine tuning framework] (https://huggingface.co/docs/transformers/en/training)
* [Unsloth](https://github.com/unslothai/unsloth)
* [MLX](https://github.com/ml-explore/mlx)
## Importing a model from Safetensors weights
First, create a `Modelfile` with a `FROM` command which points to the directory containing your Safetensors weights:
```dockerfile ```dockerfile
FROM /path/to/safetensors/directory FROM /path/to/safetensors/directory
``` ```
For architectures not directly convertable by Ollama, see llama.cpp's [guide](https://github.com/ggerganov/llama.cpp/blob/master/README.md#prepare-and-quantize) on conversion. After conversion, see [Import GGUF](#import-gguf). If you create the Modelfile in the same directory as the weights, you can use the command `FROM .`.
## Automatic Quantization Now run the `ollama create` command from the directory where you created the `Modelfile`:
> [!NOTE] ```shell
> Automatic quantization requires v0.1.35 or higher. ollama create my-model
```
Ollama is capable of quantizing FP16 or FP32 models to any of the supported quantizations with the `-q/--quantize` flag in `ollama create`. Lastly, test the model:
```shell
ollama run my-model
```
Ollama supports importing models for several different architectures including:
* Llama (including Llama 2, Llama 3, and Llama 3.1);
* Mistral (including Mistral 1, Mistral 2, and Mixtral);
* Gemma (including Gemma 1 and Gemma 2); and
* Phi3
This includes importing foundation models as well as any fine tuned models which which have been _fused_ with a foundation model.
## Importing a GGUF based model or adapter
If you have a GGUF based model or adapter it is possible to import it into Ollama. You can obtain a GGUF model or adapter by:
* converting a Safetensors model with the `convert_hf_to_gguf.py` from Llama.cpp;
* converting a Safetensors adapter with the `convert_lora_to_gguf.py` from Llama.cpp; or
* downloading a model or adapter from a place such as HuggingFace
To import a GGUF model, create a `Modelfile` containg:
```dockerfile
FROM /path/to/file.gguf
```
For a GGUF adapter, create the `Modelfile` with:
```dockerfile
FROM <model name>
ADAPTER /path/to/file.gguf
```
When importing a GGUF adapter, it's important to use the same base model as the base model that the adapter was created with. You can use:
* a model from Ollama
* a GGUF file
* a Safetensors based model
Once you have created your `Modelfile`, use the `ollama create` command to build the model.
```shell
ollama create my-model
```
## Quantizing a Model
Quantizing a model allows you to run models faster and with less memory consumption but at reduced accuracy. This allows you to run a model on more modest hardware.
Ollama can quantize FP16 and FP32 based models into different quantization levels using the `-q/--quantize` flag with the `ollama create` command.
First, create a Modelfile with the FP16 or FP32 based model you wish to quantize.
```dockerfile ```dockerfile
FROM /path/to/my/gemma/f16/model FROM /path/to/my/gemma/f16/model
``` ```
Use `ollama create` to then create the quantized model.
```shell ```shell
$ ollama create -q Q4_K_M mymodel $ ollama create --quantize q4_K_M mymodel
transferring model data transferring model data
quantizing F16 model to Q4_K_M quantizing F16 model to Q4_K_M
creating new layer sha256:735e246cc1abfd06e9cdcf95504d6789a6cd1ad7577108a70d9902fef503c1bd creating new layer sha256:735e246cc1abfd06e9cdcf95504d6789a6cd1ad7577108a70d9902fef503c1bd
@@ -49,42 +134,53 @@ success
### Supported Quantizations ### Supported Quantizations
- `Q4_0` - `q4_0`
- `Q4_1` - `q4_1`
- `Q5_0` - `q5_0`
- `Q5_1` - `q5_1`
- `Q8_0` - `q8_0`
#### K-means Quantizations #### K-means Quantizations
- `Q3_K_S` - `q3_K_S`
- `Q3_K_M` - `q3_K_M`
- `Q3_K_L` - `q3_K_L`
- `Q4_K_S` - `q4_K_S`
- `Q4_K_M` - `q4_K_M`
- `Q5_K_S` - `q5_K_S`
- `Q5_K_M` - `q5_K_M`
- `Q6_K` - `q6_K`
## Template Detection
> [!NOTE] ## Sharing your model on ollama.com
> Template detection requires v0.1.42 or higher.
Ollama uses model metadata, specifically `tokenizer.chat_template`, to automatically create a template appropriate for the model you're importing. You can share any model you have created by pushing it to [ollama.com](https://ollama.com) so that other users can try it out.
```dockerfile First, use your browser to go to the [Ollama Sign-Up](https://ollama.com/signup) page. If you already have an account, you can skip this step.
FROM /path/to/my/gemma/model
``` <img src="images/signup.png" alt="Sign-Up" width="40%">
The `Username` field will be used as part of your model's name (e.g. `jmorganca/mymodel`), so make sure you are comfortable with the username that you have selected.
Now that you have created an account and are signed-in, go to the [Ollama Keys Settings](https://ollama.com/settings/keys) page.
Follow the directions on the page to determine where your Ollama Public Key is located.
<img src="images/ollama-keys.png" alt="Ollama Keys" width="80%">
Click on the `Add Ollama Public Key` button, and copy and paste the contents of your Ollama Public Key into the text field.
To push a model to [ollama.com](https://ollama.com), first make sure that it is named correctly with your username. You may have to use the `ollama cp` command to copy
your model to give it the correct name. Once you're happy with your model's name, use the `ollama push` command to push it to [ollama.com](https://ollama.com).
```shell ```shell
$ ollama create mymodel ollama cp mymodel myuser/mymodel
transferring model data ollama push myuser/mymodel
using autodetected template gemma-instruct ```
creating new layer sha256:baa2a0edc27d19cc6b7537578a9a7ba1a4e3214dc185ed5ae43692b319af7b84
creating new layer sha256:ba66c3309914dbef07e5149a648fd1877f030d337a4f240d444ea335008943cb Once your model has been pushed, other users can pull and run it by using the command:
writing manifest
success ```shell
ollama run myuser/mymodel
``` ```
Defining a template in the Modelfile will disable this feature which may be useful if you want to use a different template than the autodetected one.

View File

@@ -28,6 +28,11 @@ Download and extract the Linux package:
curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz | sudo tar zx -C /usr curl -fsSL https://ollama.com/download/ollama-linux-amd64.tgz | sudo tar zx -C /usr
``` ```
If you have an AMD GPU, also download and extract the ROCm package into the same location
```bash
curl -fsSL https://ollama.com/download/ollama-linux-amd64-rocm.tgz | sudo tar zx -C /usr
```
### Adding Ollama as a startup service (recommended) ### Adding Ollama as a startup service (recommended)
Create a user for Ollama: Create a user for Ollama:

View File

@@ -11,8 +11,9 @@ A model file is the blueprint to create and share models with Ollama.
- [Examples](#examples) - [Examples](#examples)
- [Instructions](#instructions) - [Instructions](#instructions)
- [FROM (Required)](#from-required) - [FROM (Required)](#from-required)
- [Build from llama3](#build-from-llama3) - [Build from llama3.1](#build-from-llama31)
- [Build from a bin file](#build-from-a-bin-file) - [Build from a Safetensors model](#build-from-a-safetensors-model)
- [Build from a GGUF file](#build-from-a-gguf-file)
- [PARAMETER](#parameter) - [PARAMETER](#parameter)
- [Valid Parameters and Values](#valid-parameters-and-values) - [Valid Parameters and Values](#valid-parameters-and-values)
- [TEMPLATE](#template) - [TEMPLATE](#template)
@@ -99,22 +100,39 @@ The `FROM` instruction defines the base model to use when creating a model.
FROM <model name>:<tag> FROM <model name>:<tag>
``` ```
#### Build from llama3 #### Build from llama3.1
```modelfile ```modelfile
FROM llama3 FROM llama3.1
``` ```
A list of available base models: A list of available base models:
<https://github.com/ollama/ollama#model-library> <https://github.com/ollama/ollama#model-library>
Additional models can be found at:
<https://ollama.com/library>
#### Build from a `bin` file #### Build from a Safetensors model
```modelfile
FROM <model directory>
```
The model directory should contain the Safetensors weights for a supported architecture.
Currently supported model architectures:
* Llama (including Llama 2, Llama 3, and Llama 3.1)
* Mistral (including Mistral 1, Mistral 2, and Mixtral)
* Gemma (including Gemma 1 and Gemma 2)
* Phi3
#### Build from a GGUF file
```modelfile ```modelfile
FROM ./ollama-model.bin FROM ./ollama-model.bin
``` ```
This bin file location should be specified as an absolute path or relative to the `Modelfile` location. The GGUF bin file location should be specified as an absolute path or relative to the `Modelfile` location.
### PARAMETER ### PARAMETER
@@ -174,7 +192,20 @@ SYSTEM """<system message>"""
### ADAPTER ### ADAPTER
The `ADAPTER` instruction is an optional instruction that specifies any LoRA adapter that should apply to the base model. The value of this instruction should be an absolute path or a path relative to the Modelfile and the file must be in a GGML file format. The adapter should be tuned from the base model otherwise the behaviour is undefined. The `ADAPTER` instruction specifies a fine tuned LoRA adapter that should apply to the base model. The value of the adapter should be an absolute path or a path relative to the Modelfile. The base model should be specified with a `FROM` instruction. If the base model is not the same as the base model that the adapter was tuned from the behaviour will be erratic.
#### Safetensor adapter
```modelfile
ADAPTER <path to safetensor adapter>
```
Currently supported Safetensor adapters:
* Llama (including Llama 2, Llama 3, and Llama 3.1)
* Mistral (including Mistral 1, Mistral 2, and Mixtral)
* Gemma (including Gemma 1 and Gemma 2)
#### GGUF adapter
```modelfile ```modelfile
ADAPTER ./ollama-lora.bin ADAPTER ./ollama-lora.bin

View File

@@ -190,7 +190,7 @@ func RunnersDir() (p string) {
} }
var paths []string var paths []string
for _, root := range []string{filepath.Dir(exe), filepath.Join(filepath.Dir(exe), ".."), cwd} { for _, root := range []string{filepath.Dir(exe), filepath.Join(filepath.Dir(exe), LibRelativeToExe()), cwd} {
paths = append(paths, paths = append(paths,
root, root,
filepath.Join(root, runtime.GOOS+"-"+runtime.GOARCH), filepath.Join(root, runtime.GOOS+"-"+runtime.GOARCH),
@@ -282,3 +282,12 @@ func Values() map[string]string {
func Var(key string) string { func Var(key string) string {
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'") return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
} }
// On windows, we keep the binary at the top directory, but
// other platforms use a "bin" directory, so this returns ".."
func LibRelativeToExe() string {
if runtime.GOOS == "windows" {
return "."
}
return ".."
}

View File

@@ -9,6 +9,8 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"github.com/ollama/ollama/envconfig"
) )
// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns // Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns
@@ -54,7 +56,7 @@ func commonAMDValidateLibDir() (string, error) {
// Installer payload location if we're running the installed binary // Installer payload location if we're running the installed binary
exe, err := os.Executable() exe, err := os.Executable()
if err == nil { if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "..", "lib", "ollama") rocmTargetDir := filepath.Join(filepath.Dir(exe), envconfig.LibRelativeToExe(), "lib", "ollama")
if rocmLibUsable(rocmTargetDir) { if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir) slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil return rocmTargetDir, nil

View File

@@ -153,7 +153,7 @@ func AMDValidateLibDir() (string, error) {
// Installer payload (if we're running from some other location) // Installer payload (if we're running from some other location)
localAppData := os.Getenv("LOCALAPPDATA") localAppData := os.Getenv("LOCALAPPDATA")
appDir := filepath.Join(localAppData, "Programs", "Ollama") appDir := filepath.Join(localAppData, "Programs", "Ollama")
rocmTargetDir := filepath.Join(appDir, "..", "lib", "ollama") rocmTargetDir := filepath.Join(appDir, envconfig.LibRelativeToExe(), "lib", "ollama")
if rocmLibUsable(rocmTargetDir) { if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ollama installed ROCm at " + rocmTargetDir) slog.Debug("detected ollama installed ROCm at " + rocmTargetDir)
return rocmTargetDir, nil return rocmTargetDir, nil

View File

@@ -653,7 +653,7 @@ func LibraryDir() string {
slog.Warn("failed to lookup working directory", "error", err) slog.Warn("failed to lookup working directory", "error", err)
} }
// Scan for any of our dependeices, and pick first match // Scan for any of our dependeices, and pick first match
for _, root := range []string{filepath.Dir(appExe), filepath.Join(filepath.Dir(appExe), ".."), cwd} { for _, root := range []string{filepath.Dir(appExe), filepath.Join(filepath.Dir(appExe), envconfig.LibRelativeToExe()), cwd} {
libDep := filepath.Join("lib", "ollama") libDep := filepath.Join("lib", "ollama")
if _, err := os.Stat(filepath.Join(root, libDep)); err == nil { if _, err := os.Stat(filepath.Join(root, libDep)); err == nil {
return filepath.Join(root, libDep) return filepath.Join(root, libDep)

1227
llm/llama.h vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
package llm package llm
// #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include // #cgo CPPFLAGS: -Illama.cpp/ggml/include
// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread // #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal // #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src // #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
@@ -9,12 +9,24 @@ package llm
// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src // #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
// #include <stdlib.h> // #include <stdlib.h>
// #include <stdatomic.h>
// #include "llama.h" // #include "llama.h"
// bool update_quantize_progress(float progress, void* data) {
// atomic_int* atomicData = (atomic_int*)data;
// int intProgress = *((int*)&progress);
// atomic_store(atomicData, intProgress);
// return true;
// }
import "C" import "C"
import ( import (
"errors" "errors"
"fmt"
"sync/atomic"
"time"
"unsafe" "unsafe"
"github.com/ollama/ollama/api"
) )
// SystemInfo is an unused example of calling llama.cpp functions using CGo // SystemInfo is an unused example of calling llama.cpp functions using CGo
@@ -22,17 +34,49 @@ func SystemInfo() string {
return C.GoString(C.llama_print_system_info()) return C.GoString(C.llama_print_system_info())
} }
func Quantize(infile, outfile string, ftype fileType) error { func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
cinfile := C.CString(infile) cinfile := C.CString(infile)
defer C.free(unsafe.Pointer(cinfile)) defer C.free(unsafe.Pointer(cinfile))
coutfile := C.CString(outfile) coutfile := C.CString(outfile)
defer C.free(unsafe.Pointer(coutfile)) defer C.free(unsafe.Pointer(coutfile))
params := C.llama_model_quantize_default_params() params := C.llama_model_quantize_default_params()
params.nthread = -1 params.nthread = -1
params.ftype = ftype.Value() params.ftype = ftype.Value()
// Initialize "global" to store progress
store := (*int32)(C.malloc(C.sizeof_int))
defer C.free(unsafe.Pointer(store))
// Initialize store value, e.g., setting initial progress to 0
atomic.StoreInt32(store, 0)
params.quantize_callback_data = unsafe.Pointer(store)
params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
ticker := time.NewTicker(30 * time.Millisecond)
done := make(chan struct{})
defer close(done)
go func() {
defer ticker.Stop()
for {
select {
case <-ticker.C:
progressInt := atomic.LoadInt32(store)
progress := *(*float32)(unsafe.Pointer(&progressInt))
fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model %d%%", 100*int(progress)/tensorCount),
})
case <-done:
fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model 100%%"),
})
return
}
}
}()
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 { if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
return errors.New("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version") return errors.New("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
} }

View File

@@ -0,0 +1,52 @@
From ed941590d59fc07b1ad21d6aa458588e47d1e446 Mon Sep 17 00:00:00 2001
From: Josh Yan <jyan00017@gmail.com>
Date: Wed, 10 Jul 2024 13:39:39 -0700
Subject: [PATCH] quantize progress
---
include/llama.h | 3 +++
src/llama.cpp | 8 ++++++++
2 files changed, 11 insertions(+)
diff --git a/include/llama.h b/include/llama.h
index bb4b05ba..613db68e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -349,6 +349,9 @@ extern "C" {
bool keep_split; // quantize to the same number of shards
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
+
+ llama_progress_callback quantize_callback; // callback to report quantization progress
+ void * quantize_callback_data; // user data for the callback
} llama_model_quantize_params;
// grammar types
diff --git a/src/llama.cpp b/src/llama.cpp
index 2b9ace28..ac640c02 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18252,6 +18252,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
const auto tn = LLM_TN(model.arch);
new_ofstream(0);
for (int i = 0; i < ml.n_tensors; ++i) {
+ if (params->quantize_callback){
+ if (!params->quantize_callback(i, params->quantize_callback_data)) {
+ return;
+ }
+ }
+
auto weight = ml.get_weight(i);
struct ggml_tensor * tensor = weight->tensor;
if (weight->idx != cur_split && params->keep_split) {
@@ -18789,6 +18795,8 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
/*.keep_split =*/ false,
/*.imatrix =*/ nullptr,
/*.kv_overrides =*/ nullptr,
+ /*.quantize_callback =*/ nullptr,
+ /*.quantize_callback_data =*/ nullptr,
};
return result;
--
2.39.3 (Apple Git-146)

View File

@@ -122,8 +122,8 @@ function buildOllama() {
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} }
New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\bin\ -Force New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\ -Force
cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\bin\ cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\
} }
function buildApp() { function buildApp() {

View File

@@ -435,11 +435,14 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return err return err
} }
tensorCount := len(baseLayer.GGML.Tensors().Items)
ft := baseLayer.GGML.KV().FileType() ft := baseLayer.GGML.KV().FileType()
if !slices.Contains([]string{"F16", "F32"}, ft.String()) { if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
return errors.New("quantization is only supported for F16 and F32 models") return errors.New("quantization is only supported for F16 and F32 models")
} else if want != ft { } else if want != ft {
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)}) fn(api.ProgressResponse{
Status: "quantizing model tensors",
})
blob, err := GetBlobsPath(baseLayer.Digest) blob, err := GetBlobsPath(baseLayer.Digest)
if err != nil { if err != nil {
@@ -453,7 +456,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
defer temp.Close() defer temp.Close()
defer os.Remove(temp.Name()) defer os.Remove(temp.Name())
if err := llm.Quantize(blob, temp.Name(), want); err != nil { if err := llm.Quantize(blob, temp.Name(), want, fn, tensorCount); err != nil {
return err return err
} }

View File

@@ -463,7 +463,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
func (s *Server) PullModelHandler(c *gin.Context) { func (s *Server) PullHandler(c *gin.Context) {
var req api.PullRequest var req api.PullRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@@ -513,7 +513,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func (s *Server) PushModelHandler(c *gin.Context) { func (s *Server) PushHandler(c *gin.Context) {
var req api.PushRequest var req api.PushRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@@ -577,7 +577,7 @@ func checkNameExists(name model.Name) error {
return nil return nil
} }
func (s *Server) CreateModelHandler(c *gin.Context) { func (s *Server) CreateHandler(c *gin.Context) {
var r api.CreateRequest var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -647,7 +647,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func (s *Server) DeleteModelHandler(c *gin.Context) { func (s *Server) DeleteHandler(c *gin.Context) {
var r api.DeleteRequest var r api.DeleteRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -680,7 +680,7 @@ func (s *Server) DeleteModelHandler(c *gin.Context) {
} }
} }
func (s *Server) ShowModelHandler(c *gin.Context) { func (s *Server) ShowHandler(c *gin.Context) {
var req api.ShowRequest var req api.ShowRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
switch { switch {
@@ -829,7 +829,7 @@ func getKVData(digest string, verbose bool) (llm.KV, error) {
return kv, nil return kv, nil
} }
func (s *Server) ListModelsHandler(c *gin.Context) { func (s *Server) ListHandler(c *gin.Context) {
ms, err := Manifests() ms, err := Manifests()
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -879,7 +879,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{Models: models}) c.JSON(http.StatusOK, api.ListResponse{Models: models})
} }
func (s *Server) CopyModelHandler(c *gin.Context) { func (s *Server) CopyHandler(c *gin.Context) {
var r api.CopyRequest var r api.CopyRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -1081,33 +1081,33 @@ func (s *Server) GenerateRoutes() http.Handler {
allowedHostsMiddleware(s.addr), allowedHostsMiddleware(s.addr),
) )
r.POST("/api/pull", s.PullModelHandler) r.POST("/api/pull", s.PullHandler)
r.POST("/api/generate", s.GenerateHandler) r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler) r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler) r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateModelHandler) r.POST("/api/create", s.CreateHandler)
r.POST("/api/push", s.PushModelHandler) r.POST("/api/push", s.PushHandler)
r.POST("/api/copy", s.CopyModelHandler) r.POST("/api/copy", s.CopyHandler)
r.DELETE("/api/delete", s.DeleteModelHandler) r.DELETE("/api/delete", s.DeleteHandler)
r.POST("/api/show", s.ShowModelHandler) r.POST("/api/show", s.ShowHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.GET("/api/ps", s.ProcessHandler) r.GET("/api/ps", s.PsHandler)
// Compatibility endpoints // Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} { for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) { r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running") c.String(http.StatusOK, "Ollama is running")
}) })
r.Handle(method, "/api/tags", s.ListModelsHandler) r.Handle(method, "/api/tags", s.ListHandler)
r.Handle(method, "/api/version", func(c *gin.Context) { r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version}) c.JSON(http.StatusOK, gin.H{"version": version.Version})
}) })
@@ -1269,7 +1269,7 @@ func streamResponse(c *gin.Context, ch chan any) {
}) })
} }
func (s *Server) ProcessHandler(c *gin.Context) { func (s *Server) PsHandler(c *gin.Context) {
models := []api.ProcessModelResponse{} models := []api.ProcessModelResponse{}
for _, v := range s.sched.loaded { for _, v := range s.sched.loaded {

View File

@@ -93,7 +93,7 @@ func TestCreateFromBin(t *testing.T) {
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -120,7 +120,7 @@ func TestCreateFromModel(t *testing.T) {
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -134,7 +134,7 @@ func TestCreateFromModel(t *testing.T) {
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
}) })
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test2", Name: "test2",
Modelfile: "FROM test", Modelfile: "FROM test",
Stream: &stream, Stream: &stream,
@@ -162,7 +162,7 @@ func TestCreateRemovesLayers(t *testing.T) {
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -182,7 +182,7 @@ func TestCreateRemovesLayers(t *testing.T) {
filepath.Join(p, "blobs", "sha256-bc80b03733773e0728011b2f4adf34c458b400e1aad48cb28d61170f3a2ad2d6"), filepath.Join(p, "blobs", "sha256-bc80b03733773e0728011b2f4adf34c458b400e1aad48cb28d61170f3a2ad2d6"),
}) })
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -210,7 +210,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -230,7 +230,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
filepath.Join(p, "blobs", "sha256-f29e82a8284dbdf5910b1555580ff60b04238b8da9d5e51159ada67a4d0d5851"), filepath.Join(p, "blobs", "sha256-f29e82a8284dbdf5910b1555580ff60b04238b8da9d5e51159ada67a4d0d5851"),
}) })
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -267,7 +267,7 @@ func TestCreateMergeParameters(t *testing.T) {
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -288,7 +288,7 @@ func TestCreateMergeParameters(t *testing.T) {
}) })
// in order to merge parameters, the second model must be created FROM the first // in order to merge parameters, the second model must be created FROM the first
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test2", Name: "test2",
Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7", Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7",
Stream: &stream, Stream: &stream,
@@ -326,7 +326,7 @@ func TestCreateMergeParameters(t *testing.T) {
} }
// slices are replaced // slices are replaced
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test2", Name: "test2",
Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7\nPARAMETER stop <|endoftext|>", Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7\nPARAMETER stop <|endoftext|>",
Stream: &stream, Stream: &stream,
@@ -371,7 +371,7 @@ func TestCreateReplacesMessages(t *testing.T) {
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -391,7 +391,7 @@ func TestCreateReplacesMessages(t *testing.T) {
filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"), filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
}) })
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test2", Name: "test2",
Modelfile: "FROM test\nMESSAGE assistant \"You're a test, Harry.\"\nMESSAGE user \"I-I'm a what?\"\nMESSAGE assistant \"A test. And a thumping good one at that, I'd wager.\"", Modelfile: "FROM test\nMESSAGE assistant \"You're a test, Harry.\"\nMESSAGE user \"I-I'm a what?\"\nMESSAGE assistant \"A test. And a thumping good one at that, I'd wager.\"",
Stream: &stream, Stream: &stream,
@@ -448,7 +448,7 @@ func TestCreateTemplateSystem(t *testing.T) {
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -488,7 +488,7 @@ func TestCreateTemplateSystem(t *testing.T) {
} }
t.Run("incomplete template", func(t *testing.T) { t.Run("incomplete template", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -500,7 +500,7 @@ func TestCreateTemplateSystem(t *testing.T) {
}) })
t.Run("template with unclosed if", func(t *testing.T) { t.Run("template with unclosed if", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -512,7 +512,7 @@ func TestCreateTemplateSystem(t *testing.T) {
}) })
t.Run("template with undefined function", func(t *testing.T) { t.Run("template with undefined function", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -531,7 +531,7 @@ func TestCreateLicenses(t *testing.T) {
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -579,7 +579,7 @@ func TestCreateDetectTemplate(t *testing.T) {
var s Server var s Server
t.Run("matched", func(t *testing.T) { t.Run("matched", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", "tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
@@ -600,7 +600,7 @@ func TestCreateDetectTemplate(t *testing.T) {
}) })
t.Run("unmatched", func(t *testing.T) { t.Run("unmatched", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,

View File

@@ -22,7 +22,7 @@ func TestDelete(t *testing.T) {
var s Server var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test", Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
}) })
@@ -31,7 +31,7 @@ func TestDelete(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test2", Name: "test2",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
}) })
@@ -52,7 +52,7 @@ func TestDelete(t *testing.T) {
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
}) })
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"}) w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "test"})
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
@@ -68,7 +68,7 @@ func TestDelete(t *testing.T) {
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
}) })
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"}) w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "test2"})
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
@@ -102,7 +102,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
w := createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"}) w := createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "test"})
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Errorf("expected status code 200, actual %d", w.Code) t.Errorf("expected status code 200, actual %d", w.Code)
} }

View File

@@ -84,7 +84,7 @@ func TestGenerateChat(t *testing.T) {
go s.sched.Run(context.TODO()) go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test", Model: "test",
Modelfile: fmt.Sprintf(`FROM %s Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """ TEMPLATE """
@@ -144,7 +144,7 @@ func TestGenerateChat(t *testing.T) {
}) })
t.Run("missing capabilities chat", func(t *testing.T) { t.Run("missing capabilities chat", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert", Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert", "general.architecture": "bert",
@@ -270,7 +270,7 @@ func TestGenerateChat(t *testing.T) {
checkChatResponse(t, w.Body, "test", "Hi!") checkChatResponse(t, w.Body, "test", "Hi!")
}) })
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-system", Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.", Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
}) })
@@ -382,7 +382,7 @@ func TestGenerate(t *testing.T) {
go s.sched.Run(context.TODO()) go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test", Model: "test",
Modelfile: fmt.Sprintf(`FROM %s Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """ TEMPLATE """
@@ -442,7 +442,7 @@ func TestGenerate(t *testing.T) {
}) })
t.Run("missing capabilities generate", func(t *testing.T) { t.Run("missing capabilities generate", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert", Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert", "general.architecture": "bert",
@@ -583,7 +583,7 @@ func TestGenerate(t *testing.T) {
checkGenerateResponse(t, w.Body, "test", "Hi!") checkGenerateResponse(t, w.Body, "test", "Hi!")
}) })
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-system", Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.", Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
}) })
@@ -652,7 +652,7 @@ func TestGenerate(t *testing.T) {
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!") checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
}) })
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-suffix", Model: "test-suffix",
Modelfile: `FROM test Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID> TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>

View File

@@ -31,13 +31,13 @@ func TestList(t *testing.T) {
var s Server var s Server
for _, n := range expectNames { for _, n := range expectNames {
createRequest(t, s.CreateModelHandler, api.CreateRequest{ createRequest(t, s.CreateHandler, api.CreateRequest{
Name: n, Name: n,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
}) })
} }
w := createRequest(t, s.ListModelsHandler, nil) w := createRequest(t, s.ListHandler, nil)
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }

View File

@@ -318,7 +318,7 @@ func TestCase(t *testing.T) {
var s Server var s Server
for _, tt := range cases { for _, tt := range cases {
t.Run(tt, func(t *testing.T) { t.Run(tt, func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: tt, Name: tt,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -334,7 +334,7 @@ func TestCase(t *testing.T) {
} }
t.Run("create", func(t *testing.T) { t.Run("create", func(t *testing.T) {
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: strings.ToUpper(tt), Name: strings.ToUpper(tt),
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
@@ -350,7 +350,7 @@ func TestCase(t *testing.T) {
}) })
t.Run("pull", func(t *testing.T) { t.Run("pull", func(t *testing.T) {
w := createRequest(t, s.PullModelHandler, api.PullRequest{ w := createRequest(t, s.PullHandler, api.PullRequest{
Name: strings.ToUpper(tt), Name: strings.ToUpper(tt),
Stream: &stream, Stream: &stream,
}) })
@@ -365,7 +365,7 @@ func TestCase(t *testing.T) {
}) })
t.Run("copy", func(t *testing.T) { t.Run("copy", func(t *testing.T) {
w := createRequest(t, s.CopyModelHandler, api.CopyRequest{ w := createRequest(t, s.CopyHandler, api.CopyRequest{
Source: tt, Source: tt,
Destination: strings.ToUpper(tt), Destination: strings.ToUpper(tt),
}) })
@@ -387,7 +387,7 @@ func TestShow(t *testing.T) {
var s Server var s Server
createRequest(t, s.CreateModelHandler, api.CreateRequest{ createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "show-model", Name: "show-model",
Modelfile: fmt.Sprintf( Modelfile: fmt.Sprintf(
"FROM %s\nFROM %s", "FROM %s\nFROM %s",
@@ -396,7 +396,7 @@ func TestShow(t *testing.T) {
), ),
}) })
w := createRequest(t, s.ShowModelHandler, api.ShowRequest{ w := createRequest(t, s.ShowHandler, api.ShowRequest{
Name: "show-model", Name: "show-model",
}) })