Compare commits

..

1 Commits

Author SHA1 Message Date
Michael Yang
3803ecb6a6 cmd build context 2024-07-01 16:03:52 -07:00
169 changed files with 1865 additions and 5429 deletions

View File

@@ -147,7 +147,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -58,7 +58,6 @@ jobs:
runs-on: ${{ matrix.os }}
env:
GOARCH: ${{ matrix.arch }}
CGO_ENABLED: '1'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
@@ -80,7 +79,6 @@ jobs:
- run: go generate -x ./...
if: ${{ ! startsWith(matrix.os, 'windows-') }}
name: 'Unix Go Generate'
- run: go build .
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
@@ -126,7 +124,7 @@ jobs:
strategy:
matrix:
rocm-version:
- '6.1.2'
- '6.1.1'
runs-on: linux
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps:
@@ -169,7 +167,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

View File

@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
ARG CMAKE_VERSION=3.22.1
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md
ARG CUDA_VERSION=11.3.1
ARG ROCM_VERSION=6.1.2
ARG ROCM_VERSION=6.1.1
# Copy the minimal context we need to run the generate scripts
FROM scratch AS llm-code
@@ -70,12 +70,12 @@ RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
FROM --platform=linux/arm64 rockylinux:8 AS cpu-builder-arm64
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS

View File

@@ -293,8 +293,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
### Terminal

View File

@@ -347,16 +347,7 @@ func (c *Client) Heartbeat(ctx context.Context) error {
return nil
}
// Embed generates embeddings from a model.
func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
var resp EmbedResponse
if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Embeddings generates an embedding from a model.
// Embeddings generates embeddings from a model.
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {

View File

@@ -47,9 +47,6 @@ type GenerateRequest struct {
// Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"`
// Suffix is the text that comes after the inserted text.
Suffix string `json:"suffix"`
// System overrides the model's default system message/prompt.
System string `json:"system"`
@@ -100,9 +97,6 @@ type ChatRequest struct {
// followin the request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to.
Tools []Tool `json:"tools,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
@@ -111,46 +105,9 @@ type ChatRequest struct {
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
type ToolCall struct {
Function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
} `json:"function"`
}
type Tool struct {
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
} `json:"function"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message
var a Alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}
*m = Message(a)
m.Role = strings.ToLower(m.Role)
return nil
Role string `json:"role"`
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
@@ -202,42 +159,49 @@ type Options struct {
// Runner options which must be set when the model is loaded into memory
type Runner struct {
UseNUMA bool `json:"numa,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
NumBatch int `json:"num_batch,omitempty"`
NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"`
F16KV bool `json:"f16_kv,omitempty"`
LogitsAll bool `json:"logits_all,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap *bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
UseNUMA bool `json:"numa,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
NumBatch int `json:"num_batch,omitempty"`
NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"`
F16KV bool `json:"f16_kv,omitempty"`
LogitsAll bool `json:"logits_all,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap TriState `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}
// EmbedRequest is the request passed to [Client.Embed].
type EmbedRequest struct {
// Model is the model name.
Model string `json:"model"`
type TriState int
// Input is the input to embed.
Input any `json:"input"`
const (
TriStateUndefined TriState = -1
TriStateFalse TriState = 0
TriStateTrue TriState = 1
)
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
func (b *TriState) UnmarshalJSON(data []byte) error {
var v bool
if err := json.Unmarshal(data, &v); err != nil {
return err
}
if v {
*b = TriStateTrue
}
*b = TriStateFalse
return nil
}
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings"`
func (b *TriState) MarshalJSON() ([]byte, error) {
if *b == TriStateUndefined {
return nil, nil
}
var v bool
if *b == TriStateTrue {
v = true
}
return json.Marshal(v)
}
// EmbeddingRequest is the request passed to [Client.Embeddings].
@@ -286,10 +250,8 @@ type DeleteRequest struct {
// ShowRequest is the request passed to [Client.Show].
type ShowRequest struct {
Model string `json:"model"`
System string `json:"system"`
// Template is deprecated
Model string `json:"model"`
System string `json:"system"`
Template string `json:"template"`
Verbose bool `json:"verbose"`
@@ -383,13 +345,6 @@ type ProcessModelResponse struct {
SizeVRAM int64 `json:"size_vram"`
}
type RetrieveModelResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
type TokenResponse struct {
Token string `json:"token"`
}
@@ -405,9 +360,6 @@ type GenerateResponse struct {
// Response is the textual response itself.
Response string `json:"response"`
// ToolCalls is the list of tools the model wants to call
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// Done specifies if the response is complete.
Done bool `json:"done"`
@@ -485,6 +437,19 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
continue
}
if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
val, ok := val.(bool)
if !ok {
return fmt.Errorf("option %q must be of type boolean", key)
}
if val {
field.SetInt(int64(TriStateTrue))
} else {
field.SetInt(int64(TriStateFalse))
}
continue
}
switch field.Kind() {
case reflect.Int:
switch t := val.(type) {
@@ -531,17 +496,6 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
slice[i] = str
}
field.Set(reflect.ValueOf(slice))
case reflect.Pointer:
var b bool
if field.Type() == reflect.TypeOf(&b) {
val, ok := val.(bool)
if !ok {
return fmt.Errorf("option %q must be of type boolean", key)
}
field.Set(reflect.ValueOf(&val))
} else {
return fmt.Errorf("unknown type loading config params: %v %v", field.Kind(), field.Type())
}
default:
return fmt.Errorf("unknown type loading config params: %v", field.Kind())
}
@@ -584,7 +538,7 @@ func DefaultOptions() Options {
LowVRAM: false,
F16KV: true,
UseMLock: false,
UseMMap: nil,
UseMMap: TriStateUndefined,
UseNUMA: false,
},
}
@@ -654,6 +608,19 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
} else {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
if boolVal {
out[key] = TriStateTrue
} else {
out[key] = TriStateFalse
}
continue
}
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
@@ -681,17 +648,6 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
case reflect.Pointer:
var b bool
if field.Type() == reflect.TypeOf(&b) {
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = &boolVal
} else {
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}

View File

@@ -108,27 +108,25 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
}
func TestUseMmapParsingFromJSON(t *testing.T) {
tr := true
fa := false
tests := []struct {
name string
req string
exp *bool
exp TriState
}{
{
name: "Undefined",
req: `{ }`,
exp: nil,
exp: TriStateUndefined,
},
{
name: "True",
req: `{ "use_mmap": true }`,
exp: &tr,
exp: TriStateTrue,
},
{
name: "False",
req: `{ "use_mmap": false }`,
exp: &fa,
exp: TriStateFalse,
},
}
@@ -146,52 +144,50 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
}
func TestUseMmapFormatParams(t *testing.T) {
tr := true
fa := false
tests := []struct {
name string
req map[string][]string
exp *bool
exp TriState
err error
}{
{
name: "True",
req: map[string][]string{
"use_mmap": {"true"},
"use_mmap": []string{"true"},
},
exp: &tr,
exp: TriStateTrue,
err: nil,
},
{
name: "False",
req: map[string][]string{
"use_mmap": {"false"},
"use_mmap": []string{"false"},
},
exp: &fa,
exp: TriStateFalse,
err: nil,
},
{
name: "Numeric True",
req: map[string][]string{
"use_mmap": {"1"},
"use_mmap": []string{"1"},
},
exp: &tr,
exp: TriStateTrue,
err: nil,
},
{
name: "Numeric False",
req: map[string][]string{
"use_mmap": {"0"},
"use_mmap": []string{"0"},
},
exp: &fa,
exp: TriStateFalse,
err: nil,
},
{
name: "invalid string",
req: map[string][]string{
"use_mmap": {"foo"},
"use_mmap": []string{"foo"},
},
exp: nil,
exp: TriStateUndefined,
err: fmt.Errorf("invalid bool value [foo]"),
},
}
@@ -199,35 +195,12 @@ func TestUseMmapFormatParams(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resp, err := FormatParams(test.req)
require.Equal(t, test.err, err)
require.Equal(t, err, test.err)
respVal, ok := resp["use_mmap"]
if test.exp != nil {
if test.exp != TriStateUndefined {
assert.True(t, ok, "resp: %v", resp)
assert.Equal(t, *test.exp, *respVal.(*bool))
assert.Equal(t, test.exp, respVal)
}
})
}
}
func TestMessage_UnmarshalJSON(t *testing.T) {
tests := []struct {
input string
expected string
}{
{`{"role": "USER", "content": "Hello!"}`, "user"},
{`{"role": "System", "content": "Initialization complete."}`, "system"},
{`{"role": "assistant", "content": "How can I help you?"}`, "assistant"},
{`{"role": "TOOl", "content": "Access granted."}`, "tool"},
}
for _, test := range tests {
var msg Message
if err := json.Unmarshal([]byte(test.input), &msg); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if msg.Role != test.expected {
t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected)
}
}
}

View File

@@ -127,10 +127,6 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
[InstallDelete]
Type: filesandordirs; Name: "{%TEMP}\ollama*"
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
[Messages]
WizardReady=Ollama Windows Preview
ReadyLabel1=%nLet's get you up and running with your own large language models.

View File

@@ -3,6 +3,7 @@ package cmd
import (
"archive/zip"
"bytes"
"cmp"
"context"
"crypto/ed25519"
"crypto/rand"
@@ -11,6 +12,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log"
"math"
"net"
@@ -70,30 +72,24 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
status := "transferring model data"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
createCtx, err := cmd.Flags().GetString("context")
if err != nil {
return err
}
createCtx = cmp.Or(createCtx, filepath.Dir(filename))
fsys := os.DirFS(createCtx)
for i := range modelfile.Commands {
switch modelfile.Commands[i].Name {
case "model", "adapter":
path := modelfile.Commands[i].Args
if path == "~" {
path = home
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(home, path[2:])
}
p := filepath.Clean(modelfile.Commands[i].Args)
if !filepath.IsAbs(path) {
path = filepath.Join(filepath.Dir(filename), path)
}
fi, err := os.Stat(path)
fi, err := fs.Stat(fsys, p)
if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
continue
} else if err != nil {
@@ -103,16 +99,29 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if fi.IsDir() {
// this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(path)
sub, err := fs.Sub(fsys, p)
if err != nil {
return err
}
defer os.RemoveAll(tempfile)
path = tempfile
temp, err := os.CreateTemp(createCtx, "*.zip")
if err != nil {
return err
}
defer temp.Close()
defer os.RemoveAll(temp.Name())
if err := zipFiles(sub, temp); err != nil {
return err
}
p, err = filepath.Rel(createCtx, temp.Name())
if err != nil {
return err
}
}
digest, err := createBlob(cmd, client, path)
digest, err := createBlob(cmd, client, fsys, p)
if err != nil {
return err
}
@@ -155,42 +164,34 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}
func tempZipFiles(path string) (string, error) {
tempfile, err := os.CreateTemp("", "ollama-tf")
if err != nil {
return "", err
}
defer tempfile.Close()
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
func zipFiles(fsys fs.FS, w io.Writer) error {
detectContentType := func(name string) (string, error) {
f, err := fsys.Open(name)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
bts, err := io.ReadAll(io.LimitReader(f, 512))
if err != nil {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
return contentType, nil
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
matches, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, err
}
for _, safetensor := range matches {
if ct, err := detectContentType(safetensor); err != nil {
for _, match := range matches {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
}
}
@@ -198,73 +199,73 @@ func tempZipFiles(path string) (string, error) {
}
var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
if st, _ := glob("model*.safetensors", "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
} else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
} else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else {
return "", errors.New("no safetensors or torch files found")
return errors.New("no safetensors or torch files found")
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
js, err := glob("*.json", "text/plain")
if err != nil {
return "", err
return err
}
files = append(files, js...)
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
} else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
zipfile := zip.NewWriter(tempfile)
zipfile := zip.NewWriter(w)
defer zipfile.Close()
for _, file := range files {
f, err := os.Open(file)
f, err := fsys.Open(file)
if err != nil {
return "", err
return err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return "", err
return err
}
zfi, err := zip.FileInfoHeader(fi)
if err != nil {
return "", err
return err
}
zf, err := zipfile.CreateHeader(zfi)
if err != nil {
return "", err
return err
}
if _, err := io.Copy(zf, f); err != nil {
return "", err
return err
}
}
return tempfile.Name(), nil
return nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
bin, err := os.Open(path)
func sha256sum(fsys fs.FS, name string) (string, error) {
bin, err := fsys.Open(name)
if err != nil {
return "", err
}
@@ -275,14 +276,25 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
return "", err
}
if _, err := bin.Seek(0, io.SeekStart); err != nil {
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
}
func createBlob(cmd *cobra.Command, client *api.Client, fsys fs.FS, name string) (string, error) {
bin, err := fsys.Open(name)
if err != nil {
return "", err
}
defer bin.Close()
digest, err := sha256sum(fsys, name)
if err != nil {
return "", err
}
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
return "", err
}
return digest, nil
}
@@ -843,6 +855,7 @@ type runOptions struct {
WordWrap bool
Format string
System string
Template string
Images []api.ImageData
Options map[string]interface{}
MultiModal bool
@@ -1036,6 +1049,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
Images: opts.Images,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
}
@@ -1224,6 +1238,7 @@ func NewCLI() *cobra.Command {
createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
createCmd.Flags().StringP("context", "C", "", "Context for the model")
showCmd := &cobra.Command{
Use: "show MODEL",

View File

@@ -27,6 +27,7 @@ const (
MultilineNone MultilineState = iota
MultilinePrompt
MultilineSystem
MultilineTemplate
)
func loadModel(cmd *cobra.Command, opts *runOptions) error {
@@ -93,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
@@ -202,6 +204,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.")
sb.Reset()
case MultilineTemplate:
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}
multiline = MultilineNone
@@ -320,13 +326,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
case "system":
case "system", "template":
if len(args) < 3 {
usageSet()
continue
}
multiline = MultilineSystem
if args[1] == "system" {
multiline = MultilineSystem
} else if args[1] == "template" {
multiline = MultilineTemplate
}
line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)
@@ -346,17 +356,23 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
continue
}
opts.System = sb.String() // for display in modelfile
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
// Replace the last message
opts.Messages[len(opts.Messages)-1] = newMessage
} else {
opts.Messages = append(opts.Messages, newMessage)
if args[1] == "system" {
opts.System = sb.String() // for display in modelfile
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
// Replace the last message
opts.Messages[len(opts.Messages)-1] = newMessage
} else {
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}
fmt.Println("Set system message.")
sb.Reset()
sb.Reset()
continue
@@ -377,6 +393,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
req := &api.ShowRequest{
Name: opts.Model,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
}
resp, err := client.Show(cmd.Context(), req)
@@ -420,9 +437,12 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("No system message was specified for this model.")
}
case "template":
if resp.Template != "" {
switch {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template)
} else {
default:
fmt.Println("No prompt template was specified for this model.")
}
default:
@@ -516,6 +536,10 @@ func buildModelfile(opts runOptions) string {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
}
if opts.Template != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}
keys := make([]string, 0)
for k := range opts.Options {
keys = append(keys, k)

View File

@@ -59,6 +59,7 @@ func TestModelfileBuilder(t *testing.T) {
opts := runOptions{
Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
@@ -74,6 +75,7 @@ func TestModelfileBuilder(t *testing.T) {
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
@@ -95,6 +97,7 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]

View File

@@ -104,7 +104,7 @@ like to use. For example, to compile an optimized binary for an Intel i9-9880H,
you might use:
```
OLLAMA_CUSTOM_CPU_DEFS="-DGGML_AVX=on -DGGML_AVX2=on -DGGML_F16C=on -DGGML_FMA=on" go generate ./...
OLLAMA_CUSTOM_CPU_DEFS="-DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_F16C=on -DLLAMA_FMA=on" go generate ./...
go build .
```

View File

@@ -266,10 +266,8 @@ If there is insufficient available memory to load a new model request while one
Parallel request processing for a given model results in increasing the context size by the number of parallel requests. For example, a 2K context with 4 parallel requests will result in an 8K context and additional memory allocation.
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
The following server settings may be used to adjust how Ollama handles concurrent requests:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.

View File

@@ -27,11 +27,6 @@ chat_completion = client.chat.completions.create(
],
model='llama3',
)
completion = client.completions.create(
model="llama3",
prompt="Say this is a test"
)
```
### OpenAI JavaScript library
@@ -50,11 +45,6 @@ const chatCompletion = await openai.chat.completions.create({
messages: [{ role: 'user', content: 'Say this is a test' }],
model: 'llama3',
})
const completion = await openai.completions.create({
model: "llama3",
prompt: "Say this is a test.",
})
```
### `curl`
@@ -75,13 +65,6 @@ curl http://localhost:11434/v1/chat/completions \
}
]
}'
curl http://localhost:11434/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama3",
"prompt": "Say this is a test"
}'
```
## Endpoints
@@ -119,71 +102,8 @@ curl http://localhost:11434/v1/completions \
- [ ] `user`
- [ ] `n`
### `/v1/completions`
#### Supported features
- [x] Completions
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [ ] Logprobs
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `frequency_penalty`
- [x] `presence_penalty`
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [x] `suffix`
- [ ] `best_of`
- [ ] `echo`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`
#### Notes
- `prompt` currently only accepts a string
### `/v1/completions`
#### Supported features
- [x] Completions
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [ ] Logprobs
#### Supported request fields
- [x] `model`
- [x] `prompt`
- [x] `frequency_penalty`
- [x] `presence_penalty`
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [ ] `best_of`
- [ ] `echo`
- [ ] `suffix`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`
#### Notes
- `prompt` currently only accepts a string
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
## Models

View File

@@ -70,18 +70,14 @@ curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
## NVIDIA GPU Discovery
## Container fails to run on NVIDIA GPU
When Ollama starts up, it takes inventory of the GPUs present in the system to determine compatibility and how much VRAM is available. Sometimes this discovery can fail to find your GPUs. In general, running the latest driver will yield the best results.
Make sure you've set up the container runtime first as described in [docker.md](./docker.md)
### Linux NVIDIA Troubleshooting
Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
- If you are using a container, is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
- Is the uvm driver loaded? `sudo nvidia-modprobe -u`
- Is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
- Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
- Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
- Try rebooting
- Make sure you're running the latest nvidia drivers
@@ -89,8 +85,3 @@ Sometimes the Ollama can have difficulties initializing the GPU. When you check
If none of those resolve the problem, gather additional information and file an issue:
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
## Windows Terminal Errors
Older versions of Windows 10 (e.g., 21H1) are known to have a bug where the standard terminal program does not display control characters correctly. This can result in a long string of strings like `←[?25h←[?25l` being displayed, sometimes erroring with `The parameter is incorrect` To resolve this problem, please update to Win 10 22H1 or newer.

View File

@@ -19,7 +19,7 @@ Logs will often be helpful in diagnosing the problem (see
## System Requirements
* Windows 10 22H2 or newer, Home or Pro
* Windows 10 or newer, Home or Pro
* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card

View File

@@ -4,14 +4,12 @@ import (
"errors"
"fmt"
"log/slog"
"math"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
type OllamaHost struct {
@@ -36,17 +34,17 @@ var (
// Set via OLLAMA_HOST in the environment
Host *OllamaHost
// Set via OLLAMA_KEEP_ALIVE in the environment
KeepAlive time.Duration
KeepAlive string
// Set via OLLAMA_LLM_LIBRARY in the environment
LLMLibrary string
// Set via OLLAMA_MAX_LOADED_MODELS in the environment
MaxRunners int
// Set via OLLAMA_MAX_QUEUE in the environment
MaxQueuedRequests int
// Set via OLLAMA_MAX_VRAM in the environment
MaxVRAM uint64
// Set via OLLAMA_MODELS in the environment
ModelsDir string
// Set via OLLAMA_MAX_VRAM in the environment
MaxVRAM uint64
// Set via OLLAMA_NOHISTORY in the environment
NoHistory bool
// Set via OLLAMA_NOPRUNE in the environment
@@ -134,7 +132,6 @@ func init() {
NumParallel = 0 // Autoselect
MaxRunners = 0 // Autoselect
MaxQueuedRequests = 512
KeepAlive = 5 * time.Minute
LoadConfig()
}
@@ -269,10 +266,7 @@ func LoadConfig() {
}
}
ka := clean("OLLAMA_KEEP_ALIVE")
if ka != "" {
loadKeepAlive(ka)
}
KeepAlive = clean("OLLAMA_KEEP_ALIVE")
var err error
ModelsDir, err = getModelsDir()
@@ -350,24 +344,3 @@ func getOllamaHost() (*OllamaHost, error) {
Port: port,
}, nil
}
func loadKeepAlive(ka string) {
v, err := strconv.Atoi(ka)
if err != nil {
d, err := time.ParseDuration(ka)
if err == nil {
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
} else {
d := time.Duration(v) * time.Second
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
}

View File

@@ -2,10 +2,8 @@ package envconfig
import (
"fmt"
"math"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -25,21 +23,6 @@ func TestConfig(t *testing.T) {
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig()
require.True(t, FlashAttention)
t.Setenv("OLLAMA_KEEP_ALIVE", "")
LoadConfig()
require.Equal(t, 5*time.Minute, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "3")
LoadConfig()
require.Equal(t, 3*time.Second, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
LoadConfig()
require.Equal(t, 1*time.Hour, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
}
func TestClientFromEnvironment(t *testing.T) {

3
go.mod
View File

@@ -18,7 +18,6 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
@@ -72,7 +71,7 @@ require (
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0
golang.org/x/term v0.20.0
golang.org/x/text v0.15.0
golang.org/x/text v0.15.0 // indirect
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -49,17 +49,9 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
}
func commonAMDValidateLibDir() (string, error) {
// Favor our bundled version
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
@@ -95,5 +87,14 @@ func commonAMDValidateLibDir() (string, error) {
}
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

View File

@@ -84,8 +84,9 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
}
slog.Debug("hipDriverGetVersion", "version", version)
driverMajor = version / 10000000
driverMinor = (version - (driverMajor * 10000000)) / 100000
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
driverMajor = version / 1000
driverMinor = (version - (driverMajor * 1000)) / 10
return driverMajor, driverMinor, nil
}

View File

@@ -22,8 +22,8 @@ const (
var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
)
func AMDGetGPUInfo() []RocmGPUInfo {
@@ -35,11 +35,12 @@ func AMDGetGPUInfo() []RocmGPUInfo {
}
defer hl.Release()
driverMajor, driverMinor, err := hl.AMDDriverVersion()
if err != nil {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug("error looking up amd driver version", "error", err)
}
// TODO - this reports incorrect version information, so omitting for now
// driverMajor, driverMinor, err := hl.AMDDriverVersion()
// if err != nil {
// // For now this is benign, but we may eventually need to fail compatibility checks
// slog.Debug("error looking up amd driver version", "error", err)
// }
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount()
@@ -131,8 +132,10 @@ func AMDGetGPUInfo() []RocmGPUInfo {
MinimumMemory: rocmMinimumMemory,
Name: name,
Compute: gfx,
DriverMajor: driverMajor,
DriverMinor: driverMinor,
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
// DriverMajor: driverMajor,
// DriverMinor: driverMinor,
},
index: i,
}

View File

@@ -202,7 +202,7 @@ func GetGPUInfo() GpuInfoList {
}()
if !bootstrapped {
slog.Info("looking for compatible GPUs")
slog.Debug("Detecting GPUs")
needRefresh = false
cpuCapability = GetCPUCapability()
var memInfo C.mem_info_t
@@ -274,28 +274,6 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor
// query the management library as well so we can record any skew between the two
// which represents overhead on the GPU we must set aside on subsequent updates
if cHandles.nvml != nil {
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
slog.Info("detected OS VRAM overhead",
"id", gpuInfo.ID,
"library", gpuInfo.Library,
"compute", gpuInfo.Compute,
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
"name", gpuInfo.Name,
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
)
}
}
}
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo)
}
@@ -342,9 +320,6 @@ func GetGPUInfo() GpuInfoList {
rocmGPUs = AMDGetGPUInfo()
bootstrapped = true
if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 {
slog.Info("no compatible GPUs were discovered")
}
}
// For detected GPUs, load library if not loaded
@@ -360,17 +335,14 @@ func GetGPUInfo() GpuInfoList {
"before",
"total", format.HumanBytes2(cpus[0].TotalMemory),
"free", format.HumanBytes2(cpus[0].FreeMemory),
"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
),
slog.Group(
"now",
"total", format.HumanBytes2(mem.TotalMemory),
"free", format.HumanBytes2(mem.FreeMemory),
"free_swap", format.HumanBytes2(mem.FreeSwap),
),
)
cpus[0].FreeMemory = mem.FreeMemory
cpus[0].FreeSwap = mem.FreeSwap
}
var memInfo C.mem_info_t
@@ -399,14 +371,9 @@ func GetGPUInfo() GpuInfoList {
slog.Warn("error looking up nvidia GPU memory")
continue
}
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
// When using the management library update based on recorded overhead
memInfo.free -= C.uint64_t(gpu.OSOverhead)
}
slog.Debug("updating cuda memory data",
"gpu", gpu.ID,
"name", gpu.Name,
"overhead", format.HumanBytes2(gpu.OSOverhead),
slog.Group(
"before",
"total", format.HumanBytes2(gpu.TotalMemory),
@@ -547,23 +514,7 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
defer C.free(unsafe.Pointer(lib))
C.nvcuda_init(lib, &resp)
if resp.err != nil {
// Decide what log level based on the type of error message to help users understand why
msg := C.GoString(resp.err)
switch resp.cudaErr {
case C.CUDA_ERROR_INSUFFICIENT_DRIVER, C.CUDA_ERROR_SYSTEM_DRIVER_MISMATCH:
slog.Warn("version mismatch between driver and cuda driver library - reboot or upgrade may be required", "library", libPath, "error", msg)
case C.CUDA_ERROR_NO_DEVICE:
slog.Info("no nvidia devices detected", "library", libPath)
case C.CUDA_ERROR_UNKNOWN:
slog.Warn("unknown error initializing cuda driver library", "library", libPath, "error", msg)
slog.Warn("see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information")
default:
if strings.Contains(msg, "wrong ELF class") {
slog.Debug("skipping 32bit library", "library", libPath)
} else {
slog.Info("unable to load cuda driver library", "library", libPath, "error", msg)
}
}
slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
} else {
return int(resp.num_devices), &resp.ch, libPath

View File

@@ -56,8 +56,7 @@ func GetCPUInfo() GpuInfoList {
func GetCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: uint64(C.getFreeMemory()),
// FreeSwap omitted as Darwin uses dynamic paging
FreeMemory: 0,
}, nil
}

View File

@@ -2,4 +2,3 @@
#include <stdint.h>
uint64_t getRecommendedMaxVRAM();
uint64_t getPhysicalMemory();
uint64_t getFreeMemory();

View File

@@ -1,5 +1,4 @@
#import <Foundation/Foundation.h>
#import <mach/mach.h>
// go:build darwin
#include "gpu_info_darwin.h"
uint64_t getRecommendedMaxVRAM() {
@@ -9,27 +8,6 @@ uint64_t getRecommendedMaxVRAM() {
return result;
}
// getPhysicalMemory returns the total physical memory in bytes
uint64_t getPhysicalMemory() {
return [NSProcessInfo processInfo].physicalMemory;
}
// getFreeMemory returns the total free memory in bytes, including inactive
// memory that can be reclaimed by the system.
uint64_t getFreeMemory() {
mach_port_t host_port = mach_host_self();
mach_msg_type_number_t host_size = sizeof(vm_statistics64_data_t) / sizeof(integer_t);
vm_size_t pagesize;
vm_statistics64_data_t vm_stat;
host_page_size(host_port, &pagesize);
if (host_statistics64(host_port, HOST_VM_INFO64, (host_info64_t)&vm_stat, &host_size) != KERN_SUCCESS) {
return 0;
}
uint64_t free_memory = (uint64_t)vm_stat.free_count * pagesize;
free_memory += (uint64_t)vm_stat.speculative_count * pagesize;
free_memory += (uint64_t)vm_stat.inactive_count * pagesize;
return free_memory;
return [[NSProcessInfo processInfo] physicalMemory];
}

View File

@@ -7,7 +7,6 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
CUresult ret;
resp->err = NULL;
resp->num_devices = 0;
resp->cudaErr = CUDA_SUCCESS;
const int buflen = 256;
char buf[buflen + 1];
int i;
@@ -39,7 +38,6 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
nvcuda_lib_path, msg);
free(msg);
resp->err = strdup(buf);
resp->cudaErr = -1;
return;
}
@@ -54,7 +52,6 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
msg);
free(msg);
resp->err = strdup(buf);
resp->cudaErr = -1;
return;
}
}
@@ -64,9 +61,12 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "cuda driver library init failure: %d", ret);
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return;
}
snprintf(buf, buflen, "nvcuda init failure: %d", ret);
resp->err = strdup(buf);
resp->cudaErr = ret;
return;
}
@@ -91,7 +91,6 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
resp->ch.handle = NULL;
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
resp->cudaErr = ret;
return;
}
}
@@ -107,13 +106,13 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
if (h.handle == NULL) {
resp->err = strdup("cuda driver library handle isn't initialized");
resp->err = strdup("nvcuda handle isn't initialized");
return;
}
ret = (*h.cuDeviceGet)(&device, i);
if (ret != CUDA_SUCCESS) {
snprintf(buf, buflen, "cuda driver library device failed to initialize");
snprintf(buf, buflen, "nvcuda device failed to initialize");
resp->err = strdup(buf);
return;
}
@@ -169,14 +168,14 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
// To get memory we have to set (and release) a context
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
if (ret != CUDA_SUCCESS) {
snprintf(buf, buflen, "cuda driver library failed to get device context %d", ret);
snprintf(buf, buflen, "nvcuda failed to get device context %d", ret);
resp->err = strdup(buf);
return;
}
ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
if (ret != CUDA_SUCCESS) {
snprintf(buf, buflen, "cuda driver library device memory info lookup failure %d", ret);
snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
resp->err = strdup(buf);
// Best effort on failure...
(*h.cuCtxDestroy)(ctx);
@@ -194,7 +193,7 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
ret = (*h.cuCtxDestroy)(ctx);
if (ret != CUDA_SUCCESS) {
LOG(1, "cuda driver library failed to release device context %d", ret);
LOG(1, "nvcuda failed to release device context %d", ret);
}
}
@@ -207,7 +206,7 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
ret = (*h.cuDeviceGet)(&device, i);
if (ret != CUDA_SUCCESS) {
LOG(1, "cuda driver library device failed to initialize");
LOG(1, "nvcuda device failed to initialize");
return;
}
@@ -215,13 +214,13 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
// To get memory we have to set (and release) a context
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
if (ret != CUDA_SUCCESS) {
LOG(1, "cuda driver library failed to get device context %d", ret);
LOG(1, "nvcuda failed to get device context %d", ret);
return;
}
ret = (*h.cuMemGetInfo_v2)(free, total);
if (ret != CUDA_SUCCESS) {
LOG(1, "cuda driver library device memory info lookup failure %d", ret);
LOG(1, "nvcuda device memory info lookup failure %d", ret);
// Best effort on failure...
(*h.cuCtxDestroy)(ctx);
return;
@@ -229,12 +228,12 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
ret = (*h.cuCtxDestroy)(ctx);
if (ret != CUDA_SUCCESS) {
LOG(1, "cuda driver library failed to release device context %d", ret);
LOG(1, "nvcuda failed to release device context %d", ret);
}
}
void nvcuda_release(nvcuda_handle_t h) {
LOG(h.verbose, "releasing cuda driver library\n");
LOG(h.verbose, "releasing nvcuda library\n");
UNLOAD_LIBRARY(h.handle);
// TODO and other context release logic?
h.handle = NULL;

View File

@@ -7,12 +7,9 @@
typedef enum cudaError_enum {
CUDA_SUCCESS = 0,
CUDA_ERROR_INVALID_VALUE = 1,
CUDA_ERROR_OUT_OF_MEMORY = 2,
CUDA_ERROR_MEMORY_ALLOCATION = 2,
CUDA_ERROR_NOT_INITIALIZED = 3,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
CUDA_ERROR_NO_DEVICE = 100,
CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803,
CUDA_ERROR_UNKNOWN = 999,
// Other values omitted for now...
} CUresult;
@@ -67,7 +64,6 @@ typedef struct nvcuda_init_resp {
char *err; // If err is non-null handle is invalid
nvcuda_handle_t ch;
int num_devices;
CUresult cudaErr;
} nvcuda_init_resp_t;
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);

View File

@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
func GetCPUMem() (memInfo, error) {
var mem memInfo
var total, available, free, buffers, cached, freeSwap uint64
var total, available, free, buffers, cached uint64
f, err := os.Open("/proc/meminfo")
if err != nil {
return mem, err
@@ -70,21 +70,20 @@ func GetCPUMem() (memInfo, error) {
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
case strings.HasPrefix(line, "Cached:"):
_, err = fmt.Sscanf(line, "Cached:%d", &cached)
case strings.HasPrefix(line, "SwapFree:"):
_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
default:
continue
}
if err != nil {
return mem, err
}
if total > 0 && available > 0 {
mem.TotalMemory = total * format.KibiByte
mem.FreeMemory = available * format.KibiByte
return mem, nil
}
}
mem.TotalMemory = total * format.KibiByte
mem.FreeSwap = freeSwap * format.KibiByte
if available > 0 {
mem.FreeMemory = available * format.KibiByte
} else {
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
}
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
return mem, nil
}

View File

@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
if r1 == 0 {
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
}
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
}

View File

@@ -10,7 +10,6 @@ import (
type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"`
FreeSwap uint64 `json:"free_swap,omitempty"`
}
// Beginning of an `ollama info` command
@@ -53,8 +52,7 @@ type CPUInfo struct {
type CudaGPUInfo struct {
GpuInfo
OSOverhead uint64 // Memory overhead between the driver library and management library
index int //nolint:unused,nolintlint
index int //nolint:unused,nolintlint
}
type CudaGPUInfoList []CudaGPUInfo

View File

@@ -1,152 +0,0 @@
//go:build integration
package integration
import (
"context"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if res.Embeddings[0][0] != 0.010071031 {
t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
}
}
func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"},
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 2 {
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
}
}
func TestAllMiniLmEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
truncTrue, truncFalse := true, false
type testReq struct {
Name string
Request api.EmbedRequest
}
reqs := []testReq{
{
Name: "Target Truncation",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why",
},
},
{
Name: "Default Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 1},
},
},
{
Name: "Explicit Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
},
}
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := embedTestHelper(ctx, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})
if err == nil {
t.Fatal("expected error, got nil")
}
}
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
response, err := client.Embed(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
}

View File

@@ -1,13 +1,14 @@
set(TARGET ollama_llama_server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE ggml llama common llava ${CMAKE_THREAD_LIBS_INIT})
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()
set(TARGET ollama_llama_server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()
target_compile_features(${TARGET} PRIVATE cxx_std_11)

View File

@@ -1382,50 +1382,12 @@ struct llama_server_context
}
}
std::string common_prefix(const std::string& str1, const std::string& str2) {
auto mismatch_pair = std::mismatch(str1.begin(), str1.end(), str2.begin());
return std::string(str1.begin(), mismatch_pair.first);
}
// Find the slot that has the greatest common prefix
server_slot *prefix_slot(const json &prompt) {
if (!prompt.is_string()) {
return nullptr;
}
std::string prompt_str = prompt.get<std::string>();
server_slot *slot = nullptr;
size_t longest = 0;
for (server_slot &s : slots) {
if (s.available() && s.prompt.is_string()) {
std::string s_prompt = s.prompt.get<std::string>();
std::string prefix = common_prefix(s_prompt, prompt_str);
if (prefix.size() > longest) {
slot = &s;
longest = prefix.size();
}
}
}
if (!slot) {
return get_slot(-1);
}
LOG_DEBUG("slot with common prefix found", {{
"slot_id", slot->id,
"characters", longest
}});
return slot;
}
void process_single_task(task_server& task)
{
switch (task.type)
{
case TASK_TYPE_COMPLETION: {
server_slot *slot = prefix_slot(task.data["prompt"]);
server_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
if (slot == nullptr)
{
// if no slot is available, we defer this task for processing later
@@ -1688,8 +1650,22 @@ struct llama_server_context
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
char buf[256];
llama_model_meta_val_str(model, "general.architecture", buf, 256);
bool gemma2 = strcmp(buf, "gemma2") == 0;
int32_t truncate_at = slot.n_ctx;
// truncate at 2/3 of the context length for gemma2 models
// as they do not support context shifts (from the sliding window implementation).
// this way, prompts that almost fit the context length can still generate a full
// response without a sudden stop from hitting the context limit
if (gemma2) {
truncate_at = 2 * slot.n_ctx / 3;
}
// if input prompt is too big, truncate it, if group attention self-extend is disabled
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
{
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_shift = n_left / 2;
@@ -1717,6 +1693,19 @@ struct llama_server_context
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
// Models with sliding window attention do not work with context shifts, so
// limit their prediction to the context length
if (gemma2) {
int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
slot.n_predict = limit;
slot.params.n_predict = limit;
LOG_INFO("model does not support sliding window, limiting generation", {
{"n_ctx", slot.n_ctx},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"n_predict", slot.n_predict}
});
}
if (!slot.params.cache_prompt)
{
llama_sampling_reset(slot.ctx_sampling);
@@ -1743,7 +1732,7 @@ struct llama_server_context
slot.n_past -= 1;
}
slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
if (slot.ga_n != 1)
{
@@ -3188,33 +3177,26 @@ int main(int argc, char **argv) {
prompt = "";
}
if (prompt.size() == 1) {
prompt = prompt[0];
json image_data;
if (body.count("image_data") != 0) {
image_data = body["image_data"];
}
else
{
image_data = "";
}
// create and queue the task
json responses;
{
const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
// get the result
task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}
// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
responses = result.result_json.value("results", std::vector<json>{result.result_json});
json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json embedding_res = json{{"embedding", embeddings}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
});
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

View File

@@ -18,16 +18,16 @@ sign() {
fi
}
COMMON_DARWIN_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DGGML_METAL_EMBED_LIBRARY=on -DGGML_OPENMP=off"
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_OPENMP=off"
case "${GOARCH}" in
"amd64")
COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DGGML_METAL=off -DGGML_NATIVE=off"
COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_NATIVE=off"
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_BLAS=off -DGGML_ACCELERATE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_BLAS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}_static"
echo "Building static library"
build
@@ -37,7 +37,7 @@ case "${GOARCH}" in
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=off -DGGML_BLAS=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_BLAS=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
@@ -49,7 +49,7 @@ case "${GOARCH}" in
# Approximately 400% faster than LCD on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=off -DGGML_BLAS=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_BLAS=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
@@ -61,7 +61,7 @@ case "${GOARCH}" in
# Approximately 10% faster than AVX on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=on -DGGML_BLAS=off -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_BLAS=off -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
@@ -75,14 +75,14 @@ case "${GOARCH}" in
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_BLAS=off -DCMAKE_SYSTEM_NAME=Darwin -DBUILD_SHARED_LIBS=off -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}_static"
echo "Building static library"
build
if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then
init_vars
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build

View File

@@ -51,7 +51,7 @@ if [ -z "${CUDACXX}" ]; then
export CUDACXX=$(command -v nvcc)
fi
fi
COMMON_CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off"
COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DLLAMA_OPENMP=off"
source $(dirname $0)/gen_common.sh
init_vars
git_module_setup
@@ -64,7 +64,7 @@ if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ];
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DGGML_NATIVE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off ${CMAKE_DEFS}"
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DLLAMA_OPENMP=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}_static"
echo "Building static library"
build
@@ -77,29 +77,29 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
init_vars
echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cpu"
echo "Building custom CPU"
build
compress
else
# Darwin Rosetta x86 emulation does NOT support AVX, AVX2, AVX512
# -DGGML_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DGGML_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
# -DGGML_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DGGML_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
# Note: the following seem to yield slower results than AVX2 - ymmv
# -DGGML_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT)
# -DGGML_AVX512_VBMI -- 2018 Intel Cannon Lake
# -DGGML_AVX512_VNNI -- 2021 Intel Alder Lake
# -DLLAMA_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT)
# -DLLAMA_AVX512_VBMI -- 2018 Intel Cannon Lake
# -DLLAMA_AVX512_VNNI -- 2021 Intel Alder Lake
COMMON_CPU_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_OPENMP=off"
COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_OPENMP=off"
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cpu"
echo "Building LCD CPU"
build
@@ -116,7 +116,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
# Approximately 400% faster than LCD on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
@@ -129,7 +129,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
# Approximately 10% faster than AVX on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
build
@@ -170,15 +170,15 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
#
# CUDA compute < 6.0 lacks proper FP16 support on ARM.
# Disabling has minimal performance effect while maintaining compatibility.
ARM64_DEFS="-DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_CUDA_F16=off"
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
fi
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
echo "Building custom CUDA GPU"
else
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
fi
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
@@ -216,7 +216,7 @@ if [ -z "${OLLAMA_SKIP_ONEAPI_GENERATE}" -a -d "${ONEAPI_ROOT}" ]; then
init_vars
source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI
CC=icx
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL=ON -DGGML_SYCL_F16=OFF"
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL=ON -DLLAMA_SYCL_F16=OFF"
BUILD_DIR="../build/linux/${ARCH}/oneapi"
EXTRA_LIBS="-fsycl -Wl,-rpath,${ONEAPI_ROOT}/compiler/latest/lib,-rpath,${ONEAPI_ROOT}/mkl/latest/lib,-rpath,${ONEAPI_ROOT}/tbb/latest/lib,-rpath,${ONEAPI_ROOT}/compiler/latest/opt/oclfpga/linux64/lib -lOpenCL -lmkl_core -lmkl_sycl_blas -lmkl_intel_ilp64 -lmkl_tbb_thread -ltbb"
DEBUG_FLAGS="" # icx compiles with -O0 if we pass -g, so we must remove it
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""

View File

@@ -6,9 +6,18 @@ function amdGPUs {
if ($env:AMDGPU_TARGETS) {
return $env:AMDGPU_TARGETS
}
# Current supported rocblas list from ROCm v6.1.2 on windows
# TODO - load from some common data file for linux + windows build consistency
$GPU_LIST = @(
"gfx900"
"gfx906:xnack-"
"gfx908:xnack-"
"gfx90a:xnack+"
"gfx90a:xnack-"
"gfx940"
"gfx941"
"gfx942"
"gfx1010"
"gfx1012"
"gfx1030"
"gfx1100"
"gfx1101"
@@ -30,8 +39,8 @@ function init_vars {
}
$script:cmakeDefs = @(
"-DBUILD_SHARED_LIBS=on",
"-DGGML_NATIVE=off",
"-DGGML_OPENMP=off"
"-DLLAMA_NATIVE=off",
"-DLLAMA_OPENMP=off"
)
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
@@ -173,9 +182,9 @@ function cleanup {
}
# -DGGML_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DGGML_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DGGML_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
function build_static() {
@@ -195,13 +204,13 @@ function build_static() {
"-DCMAKE_C_COMPILER=gcc.exe",
"-DCMAKE_CXX_COMPILER=g++.exe",
"-DBUILD_SHARED_LIBS=off",
"-DGGML_NATIVE=off",
"-DGGML_AVX=off",
"-DGGML_AVX2=off",
"-DGGML_AVX512=off",
"-DGGML_F16C=off",
"-DGGML_FMA=off",
"-DGGML_OPENMP=off")
"-DLLAMA_NATIVE=off",
"-DLLAMA_AVX=off",
"-DLLAMA_AVX2=off",
"-DLLAMA_AVX512=off",
"-DLLAMA_F16C=off",
"-DLLAMA_FMA=off",
"-DLLAMA_OPENMP=off")
$script:buildDir="../build/windows/${script:ARCH}_static"
write-host "Building static library"
build
@@ -215,7 +224,7 @@ function build_cpu($gen_arch) {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DGGML_AVX=off", "-DGGML_AVX2=off", "-DGGML_AVX512=off", "-DGGML_FMA=off", "-DGGML_F16C=off") + $script:cmakeDefs
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu"
$script:distDir="$script:DIST_BASE\cpu"
write-host "Building LCD CPU"
@@ -230,7 +239,7 @@ function build_cpu($gen_arch) {
function build_cpu_avx() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DGGML_AVX=on", "-DGGML_AVX2=off", "-DGGML_AVX512=off", "-DGGML_FMA=off", "-DGGML_F16C=off") + $script:cmakeDefs
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
$script:distDir="$script:DIST_BASE\cpu_avx"
write-host "Building AVX CPU"
@@ -245,7 +254,7 @@ function build_cpu_avx() {
function build_cpu_avx2() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DGGML_AVX=on", "-DGGML_AVX2=on", "-DGGML_AVX512=off", "-DGGML_FMA=on", "-DGGML_F16C=on") + $script:cmakeDefs
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
$script:distDir="$script:DIST_BASE\cpu_avx2"
write-host "Building AVX2 CPU"
@@ -270,9 +279,9 @@ function build_cuda() {
$script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @(
"-A", "x64",
"-DGGML_CUDA=ON",
"-DGGML_AVX=on",
"-DGGML_AVX2=off",
"-DLLAMA_CUDA=ON",
"-DLLAMA_AVX=on",
"-DLLAMA_AVX2=off",
"-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR",
"-DCMAKE_CUDA_FLAGS=-t8",
"-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}"
@@ -310,7 +319,7 @@ function build_oneapi() {
$script:distDir ="$script:DIST_BASE\oneapi$script:ONEAPI_VARIANT"
$script:cmakeDefs += @(
"-G", "MinGW Makefiles",
"-DGGML_SYCL=ON",
"-DLLAMA_SYCL=ON",
"-DCMAKE_C_COMPILER=icx",
"-DCMAKE_CXX_COMPILER=icx",
"-DCMAKE_BUILD_TYPE=Release"
@@ -356,11 +365,10 @@ function build_rocm() {
"-G", "Ninja",
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DGGML_HIPBLAS=on",
"-DLLAMA_CUDA_NO_PEER_COPY=on",
"-DLLAMA_HIPBLAS=on",
"-DHIP_PLATFORM=amd",
"-DGGML_AVX=on",
"-DGGML_AVX2=off",
"-DLLAMA_AVX=on",
"-DLLAMA_AVX2=off",
"-DCMAKE_POSITION_INDEPENDENT_CODE=on",
"-DAMDGPU_TARGETS=$(amdGPUs)",
"-DGPU_TARGETS=$(amdGPUs)"
@@ -386,6 +394,7 @@ function build_rocm() {
sign
install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"

View File

@@ -424,32 +424,6 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
)
case "chatglm":
fullOffload = 4 * batch * (embedding + vocab)
partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
fullOffload = max(
fullOffload,
4*batch*(2+
2*embedding+
context+
context*heads+
embeddingHeadsK*heads+
qkvBias.Shape[0]),
)
partialOffload = max(
partialOffload,
4*batch*(1+
2*embedding+
embeddingHeadsK*heads+
context+
context*heads)+
4*embeddingHeadsK*context+
4*context*embeddingHeadsK+
4*qkvBias.Shape[0],
)
}
}
return

View File

@@ -537,7 +537,6 @@ var ggufKVOrder = map[string][]string{
"tokenizer.ggml.add_bos_token",
"tokenizer.ggml.add_eos_token",
"tokenizer.chat_template",
"bert.pooling_type",
},
}

View File

@@ -1,13 +1,12 @@
package llm
// #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
// #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
// #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
// #cgo CFLAGS: -Illama.cpp
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
// #include <stdlib.h>
// #include "llama.h"
import "C"
@@ -33,7 +32,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
params.ftype = ftype.Value()
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
return fmt.Errorf("llama_model_quantize: %d", rc)
}
return nil

View File

@@ -1,8 +1,8 @@
diff --git a/common/common.cpp b/common/common.cpp
index 2c05a4d4..927f0e3d 100644
index 73ff0e85..6adb1a92 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2093,6 +2093,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
@@ -2447,6 +2447,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
@@ -12,10 +12,10 @@ index 2c05a4d4..927f0e3d 100644
mparams.kv_overrides = NULL;
} else {
diff --git a/common/common.h b/common/common.h
index 65c0ef81..ebca2c77 100644
index 58ed72f4..0bb2605e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -184,6 +184,13 @@ struct gpt_params {
@@ -180,6 +180,13 @@ struct gpt_params {
std::string mmproj = ""; // path to multimodal projector
std::vector<std::string> image; // path to image file(s)
@@ -26,6 +26,6 @@ index 65c0ef81..ebca2c77 100644
+ // context pointer passed to the progress callback
+ void * progress_callback_user_data;
+
// embedding
bool embedding = false; // get only sentence embedding
int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
// server params
int32_t port = 8080; // server listens on this network port
int32_t timeout_read = 600; // http read timeout in seconds

View File

@@ -1,8 +1,17 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 73f52435..58a00fb1 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7241,7 +7241,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
From 544a2d2e646d39e878d87dfbb3398a356bc560ab Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 23 May 2024 11:18:45 -0700
Subject: [PATCH] throw exception on load errors
---
llama.cpp | 25 ++++++++++++++++---------
1 file changed, 16 insertions(+), 9 deletions(-)
diff --git a/llama.cpp b/llama.cpp
index 15c66077..8ba90b6a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6346,7 +6346,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
}
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
@@ -11,7 +20,7 @@ index 73f52435..58a00fb1 100644
}
return 0;
@@ -17564,16 +17564,23 @@ struct llama_model * llama_load_model_from_file(
@@ -15600,16 +15600,23 @@ struct llama_model * llama_load_model_from_file(
}
model->rpc_servers.push_back(servers);
}
@@ -43,3 +52,6 @@ index 73f52435..58a00fb1 100644
}
return model;
--
2.45.1

View File

@@ -1,7 +1,7 @@
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
diff --git a/ggml-metal.m b/ggml-metal.m
index 0207b787..b5e9884b 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
// to the matrix-vector kernel
int ne11_mm_min = 1;

View File

@@ -1,11 +1,11 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 2b9ace28..172640e2 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -5357,16 +5357,7 @@ static void llm_load_vocab(
diff --git a/llama.cpp b/llama.cpp
index 61948751..4b72a293 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4824,16 +4824,7 @@ static void llm_load_vocab(
// for now, only BPE models have pre-tokenizers
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
vocab.tokenizer_add_space_prefix = false;
vocab.tokenizer_clean_spaces = true;
- if (tokenizer_pre.empty()) {
- LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
@@ -20,13 +20,13 @@ index 2b9ace28..172640e2 100644
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if (
tokenizer_pre == "llama3" ||
@@ -5439,7 +5430,8 @@ static void llm_load_vocab(
tokenizer_pre == "jais") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
@@ -4888,7 +4879,8 @@ static void llm_load_vocab(
tokenizer_pre == "poro-chat") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
} else {
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
}
} else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
} else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;

View File

@@ -1,7 +1,7 @@
diff --git a/src/llama.cpp b/src/llama.cpp
diff --git a/llama.cpp b/llama.cpp
index 40d2ec2c..f34eb79a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
--- a/llama.cpp
+++ b/llama.cpp
@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);

View File

@@ -1,45 +0,0 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 1fe2b9f7..a43312a7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13689,7 +13689,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
- const bool has_logits = !cparams.embeddings;
+ const bool has_logits = cparams.causal_attn;
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
@@ -13959,17 +13959,25 @@ static int llama_decode_internal(
// no output
res = nullptr;
embd = nullptr;
- } else if (cparams.embeddings) {
- res = nullptr; // do not extract logits for embedding case
- embd = gf->nodes[gf->n_nodes - 1];
- if (strcmp(embd->name, "result_embd_pooled") != 0) {
- embd = gf->nodes[gf->n_nodes - 2];
+ }
+
+ if (cparams.embeddings) {
+ for (int i = gf->n_nodes - 1; i >= 0; --i) {
+ embd = gf->nodes[i];
+ if (strcmp(embd->name, "result_embd_pooled") == 0) {
+ break;
+ }
}
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
- } else {
+ } else {
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
+
+ if (!cparams.causal_attn) {
+ res = nullptr; // do not extract logits when not needed
+ }
+
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
ggml_backend_sched_alloc_graph(lctx.sched, gf);

305
llm/patches/07-gemma.diff Normal file
View File

@@ -0,0 +1,305 @@
From 5cadb45f39d001ffbad95b690d6cf0abcb4a6d96 Mon Sep 17 00:00:00 2001
From: Ollama maintainers <hello@ollama.com>
Date: Wed, 26 Jun 2024 16:18:09 -0700
Subject: [PATCH] Architecture support
---
llama.cpp | 194 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 193 insertions(+), 1 deletion(-)
diff --git a/llama.cpp b/llama.cpp
index 61948751..3b4196f5 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -217,6 +217,7 @@ enum llm_arch {
LLM_ARCH_INTERNLM2,
LLM_ARCH_MINICPM,
LLM_ARCH_GEMMA,
+ LLM_ARCH_GEMMA2,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@@ -255,6 +256,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_INTERNLM2, "internlm2" },
{ LLM_ARCH_MINICPM, "minicpm" },
{ LLM_ARCH_GEMMA, "gemma" },
+ { LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -464,10 +466,12 @@ enum llm_tensor {
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_NORM_2,
LLM_TENSOR_ATTN_OUT_NORM,
+ LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_POST_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
@@ -960,6 +964,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_GEMMA2,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ },
+ },
{
LLM_ARCH_STARCODER2,
{
@@ -1941,6 +1963,8 @@ enum e_model {
MODEL_8x22B,
MODEL_16x12B,
MODEL_10B_128x3_66B,
+ MODEL_9B,
+ MODEL_27B,
};
static const size_t kiB = 1024;
@@ -2114,6 +2138,7 @@ struct llama_layer {
struct ggml_tensor * attn_out_norm_b;
struct ggml_tensor * attn_q_a_norm;
struct ggml_tensor * attn_kv_a_norm;
+ struct ggml_tensor * attn_post_norm;
// attention
struct ggml_tensor * wq;
@@ -2136,6 +2161,7 @@ struct llama_layer {
// normalization
struct ggml_tensor * ffn_norm;
struct ggml_tensor * ffn_norm_b;
+ struct ggml_tensor * ffn_post_norm;
struct ggml_tensor * layer_out_norm;
struct ggml_tensor * layer_out_norm_b;
struct ggml_tensor * ffn_norm_exps;
@@ -4529,6 +4555,16 @@ static void llm_load_hparams(
}
} break;
case LLM_ARCH_GEMMA:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 18: model.type = e_model::MODEL_9B; break;
+ case 28: model.type = e_model::MODEL_27B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
+ case LLM_ARCH_GEMMA2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -6305,6 +6341,40 @@ static bool llm_load_tensors(
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
}
} break;
+ case LLM_ARCH_GEMMA2:
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+ const int64_t n_ff = hparams.n_ff;
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
+ layer.attn_post_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
+
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
+ layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
+ }
+ } break;
case LLM_ARCH_STARCODER2:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -10614,6 +10684,123 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_gemma2() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+ cb(inpL, "inp_scaled", -1);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = build_inp_pos();
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+ for (int il = 0; il < n_layer; ++il) {
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
+ n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ cb(Qcur, "Qcur", il);
+
+ Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+ cb(Qcur, "Qcur_scaled", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
+ n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ cb(Kcur, "Kcur", il);
+
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+ model.layers[il].wo, NULL,
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_post_norm", il);
+
+ struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+ cb(sa_out, "sa_out", il);
+
+ cur = llm_build_norm(ctx0, sa_out, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network
+ {
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, NULL,
+ model.layers[il].ffn_gate, NULL,
+ model.layers[il].ffn_down, NULL,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "ffn_post_norm", -1);
+
+ cur = ggml_add(ctx0, cur, sa_out);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
struct ggml_cgraph * build_starcoder2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@@ -11847,6 +12034,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_gemma();
} break;
+ case LLM_ARCH_GEMMA2:
+ {
+ result = llm.build_gemma2();
+ } break;
case LLM_ARCH_STARCODER2:
{
result = llm.build_starcoder2();
@@ -16671,6 +16862,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3:
case LLM_ARCH_GEMMA:
+ case LLM_ARCH_GEMMA2:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_GPTNEOX:
return LLAMA_ROPE_TYPE_NEOX;
@@ -18551,7 +18743,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<s>assistant\n";
}
- } else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
+ } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl.find("<start_of_turn>") != std::string::npos) {
// google/gemma-7b-it
std::string system_prompt = "";
for (auto message : chat) {
--
2.45.2

View File

@@ -1,42 +0,0 @@
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 95fbe3d0..5a02a6ec 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -32,6 +33,14 @@
#include <cinttypes>
#include <limits>
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+ #define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
//#define CLIP_DEBUG_FUNCTIONS
// RGB uint8 image
@@ -1055,7 +1064,22 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
return nullptr;
}
+#ifdef _WIN32
+ int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0);
+ if (!wlen) {
+ return NULL;
+ }
+ wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
+ wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen);
+ if (!wlen) {
+ free(wbuf);
+ return NULL;
+ }
+ auto fin = std::ifstream(wbuf, std::ios::binary);
+ free(wbuf);
+#else
auto fin = std::ifstream(fname, std::ios::binary);
+#endif
if (!fin) {
LOG_TEE("cannot open model file for loading tensors\n");
clip_free(new_clip);

View File

@@ -1,60 +0,0 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 721b8f4e..cfe7ac40 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8420,14 +8420,14 @@ struct llm_build_context {
}
struct ggml_tensor * build_inp_mean() {
- lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+ lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, cparams.n_seq_max);
cb(lctx.inp_mean, "inp_mean", -1);
ggml_set_input(lctx.inp_mean);
return lctx.inp_mean;
}
struct ggml_tensor * build_inp_cls() {
- lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_seq_max);
cb(lctx.inp_cls, "inp_cls", -1);
ggml_set_input(lctx.inp_cls);
return lctx.inp_cls;
@@ -13847,19 +13847,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
float * data = (float *) lctx.inp_mean->data;
- memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
+ memset(lctx.inp_mean->data, 0, n_tokens * cparams.n_seq_max * ggml_element_size(lctx.inp_mean));
std::vector<uint64_t> sum(n_tokens, 0);
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
-
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
sum[seq_id] += 1;
}
- std::vector<float> div(n_tokens, 0.0f);
- for (int i = 0; i < n_tokens; ++i) {
+ std::vector<float> div(cparams.n_seq_max, 0.0f);
+ for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
const uint64_t s = sum[i];
if (s > 0) {
div[i] = 1.0f/float(s);
@@ -13879,14 +13876,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
- memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+ memset(lctx.inp_cls->data, 0, cparams.n_seq_max * ggml_element_size(lctx.inp_cls));
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
const llama_pos pos = batch.pos[i];
-
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
-
if (pos == 0) {
data[seq_id] = i;
}

View File

@@ -38,7 +38,7 @@ func Init() error {
}
var variants []string
for v := range getAvailableServers() {
for v := range availableServers() {
variants = append(variants, v)
}
slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
@@ -50,7 +50,7 @@ func Init() error {
// binary names may contain an optional variant separated by '_'
// For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
// Any library without a variant is the lowest common denominator
func getAvailableServers() map[string]string {
func availableServers() map[string]string {
payloadsDir, err := gpu.PayloadsDir()
if err != nil {
slog.Error("payload lookup error", "error", err)
@@ -80,7 +80,7 @@ func getAvailableServers() map[string]string {
// TODO - switch to metadata based mapping
func serversForGpu(info gpu.GpuInfo) []string {
// glob workDir for files that start with ollama_
availableServers := getAvailableServers()
availableServers := availableServers()
requested := info.Library
if info.Variant != gpu.CPUCapabilityNone {
requested += "_" + info.Variant.String()
@@ -115,29 +115,27 @@ func serversForGpu(info gpu.GpuInfo) []string {
servers = append(servers, alt...)
}
if !(runtime.GOOS == "darwin" && runtime.GOARCH == "arm64") {
// Load up the best CPU variant if not primary requested
if info.Library != "cpu" {
variant := gpu.GetCPUCapability()
// If no variant, then we fall back to default
// If we have a variant, try that if we find an exact match
// Attempting to run the wrong CPU instructions will panic the
// process
if variant != gpu.CPUCapabilityNone {
for cmp := range availableServers {
if cmp == "cpu_"+variant.String() {
servers = append(servers, cmp)
break
}
// Load up the best CPU variant if not primary requested
if info.Library != "cpu" {
variant := gpu.GetCPUCapability()
// If no variant, then we fall back to default
// If we have a variant, try that if we find an exact match
// Attempting to run the wrong CPU instructions will panic the
// process
if variant != gpu.CPUCapabilityNone {
for cmp := range availableServers {
if cmp == "cpu_"+variant.String() {
servers = append(servers, cmp)
break
}
} else {
servers = append(servers, "cpu")
}
} else {
servers = append(servers, "cpu")
}
}
if len(servers) == 0 {
servers = []string{"cpu"}
}
if len(servers) == 0 {
servers = []string{"cpu"}
}
return servers
@@ -149,7 +147,7 @@ func serverForCpu() string {
return "metal"
}
variant := gpu.GetCPUCapability()
availableServers := getAvailableServers()
availableServers := availableServers()
if variant != gpu.CPUCapabilityNone {
for cmp := range availableServers {
if cmp == "cpu_"+variant.String() {

View File

@@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embed(ctx context.Context, input []string) ([][]float32, error)
Embedding(ctx context.Context, prompt string) ([]float64, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -88,7 +88,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var estimate MemoryEstimate
var systemTotalMemory uint64
var systemFreeMemory uint64
var systemSwapFreeMemory uint64
systemMemInfo, err := gpu.GetCPUMem()
if err != nil {
@@ -96,8 +95,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} else {
systemTotalMemory = systemMemInfo.TotalMemory
systemFreeMemory = systemMemInfo.FreeMemory
systemSwapFreeMemory = systemMemInfo.FreeSwap
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
}
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
@@ -124,16 +122,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
}
}
// On linux, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
}
}
estimate.log()
// Loop through potential servers
@@ -143,20 +131,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
availableServers := getAvailableServers()
if len(availableServers) == 0 {
if runtime.GOOS != "windows" {
slog.Warn("llama server binary disappeared, reinitializing payloads")
err = Init()
if err != nil {
slog.Warn("failed to reinitialize payloads", "error", err)
return nil, err
}
availableServers = getAvailableServers()
} else {
return nil, finalErr
}
}
availableServers := availableServers()
var servers []string
if cpuRunner != "" {
servers = []string{cpuRunner}
@@ -233,8 +208,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
if g.Library == "metal" &&
uint64(opts.NumGPU) > 0 &&
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
opts.UseMMap = new(bool)
*opts.UseMMap = false
opts.UseMMap = api.TriStateFalse
}
}
@@ -245,10 +219,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
// Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache
if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
(gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
(opts.UseMMap != nil && !*opts.UseMMap) {
if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == api.TriStateUndefined) ||
(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == api.TriStateUndefined) ||
(gpus[0].Library == "cpu" && opts.UseMMap == api.TriStateUndefined) ||
opts.UseMMap == api.TriStateFalse {
params = append(params, "--no-mmap")
}
@@ -266,6 +240,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--tensor-split", estimate.TensorSplit)
}
if estimate.TensorSplit != "" {
params = append(params, "--tensor-split", estimate.TensorSplit)
}
for i := range len(servers) {
dir := availableServers[servers[i]]
if dir == "" {
@@ -582,9 +560,6 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
}
if strings.Contains(msg, "unknown model") {
return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
}
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
default:
}
@@ -687,7 +662,7 @@ type CompletionRequest struct {
Prompt string
Format string
Images []ImageData
Options *api.Options
Options api.Options
}
type CompletionResponse struct {
@@ -707,9 +682,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
defer s.sem.Release(1)
// put an upper limit on num_predict to avoid the model running on forever
// only allow maximum 10 "context shifts" to avoid infinite generation
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
}
request := map[string]any{
@@ -867,15 +843,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil
}
type EmbedRequest struct {
Content []string `json:"content"`
type EmbeddingRequest struct {
Content string `json:"content"`
}
type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"`
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
@@ -890,7 +866,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbedRequest{Content: input})
data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
@@ -917,7 +893,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
return nil, fmt.Errorf("%s", body)
}
var embedding EmbedResponse
var embedding EmbeddingResponse
if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}

View File

@@ -25,7 +25,6 @@ var errorPrefixes = []string{
"CUDA error",
"cudaMalloc failed",
"\"ERR\"",
"error loading model",
}
func (w *StatusWriter) Write(b []byte) (int, error) {

View File

@@ -3,18 +3,15 @@ package openai
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
type Error struct {
@@ -30,7 +27,7 @@ type ErrorResponse struct {
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Content string `json:"content"`
}
type Choice struct {
@@ -45,12 +42,6 @@ type ChunkChoice struct {
FinishReason *string `json:"finish_reason"`
}
type CompleteChunkChoice struct {
Text string `json:"text"`
Index int `json:"index"`
FinishReason *string `json:"finish_reason"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
@@ -61,11 +52,6 @@ type ResponseFormat struct {
Type string `json:"type"`
}
type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
@@ -99,63 +85,6 @@ type ChatCompletionChunk struct {
Choices []ChunkChoice `json:"choices"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens *int `json:"max_tokens"`
PresencePenalty float32 `json:"presence_penalty"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
Stream bool `json:"stream"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
}
type Completion struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []CompleteChunkChoice `json:"choices"`
Usage Usage `json:"usage,omitempty"`
}
type CompletionChunk struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []CompleteChunkChoice `json:"choices"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
}
type Model struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
type Embedding struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
type ListCompletion struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
type EmbeddingList struct {
Object string `json:"object"`
Data []Embedding `json:"data"`
Model string `json:"model"`
}
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
@@ -216,159 +145,10 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
}
}
func toCompletion(id string, r api.GenerateResponse) Completion {
return Completion{
Id: id,
Object: "text_completion",
Created: r.CreatedAt.Unix(),
Model: r.Model,
SystemFingerprint: "fp_ollama",
Choices: []CompleteChunkChoice{{
Text: r.Response,
Index: 0,
FinishReason: func(reason string) *string {
if len(reason) > 0 {
return &reason
}
return nil
}(r.DoneReason),
}},
Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
},
}
}
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
return CompletionChunk{
Id: id,
Object: "text_completion",
Created: time.Now().Unix(),
Model: r.Model,
SystemFingerprint: "fp_ollama",
Choices: []CompleteChunkChoice{{
Text: r.Response,
Index: 0,
FinishReason: func(reason string) *string {
if len(reason) > 0 {
return &reason
}
return nil
}(r.DoneReason),
}},
}
}
func toListCompletion(r api.ListResponse) ListCompletion {
var data []Model
for _, m := range r.Models {
data = append(data, Model{
Id: m.Name,
Object: "model",
Created: m.ModifiedAt.Unix(),
OwnedBy: model.ParseName(m.Name).Namespace,
})
}
return ListCompletion{
Object: "list",
Data: data,
}
}
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil {
var data []Embedding
for i, e := range r.Embeddings {
data = append(data, Embedding{
Object: "embedding",
Embedding: e,
Index: i,
})
}
return EmbeddingList{
Object: "list",
Data: data,
Model: model,
}
}
return EmbeddingList{}
}
func toModel(r api.ShowResponse, m string) Model {
return Model{
Id: m,
Object: "model",
Created: r.ModifiedAt.Unix(),
OwnedBy: model.ParseName(m).Namespace,
}
}
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
func fromRequest(r ChatCompletionRequest) api.ChatRequest {
var messages []api.Message
for _, msg := range r.Messages {
switch content := msg.Content.(type) {
case string:
messages = append(messages, api.Message{Role: msg.Role, Content: content})
case []any:
message := api.Message{Role: msg.Role}
for _, c := range content {
data, ok := c.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
switch data["type"] {
case "text":
text, ok := data["text"].(string)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
message.Content = text
case "image_url":
var url string
if urlMap, ok := data["image_url"].(map[string]any); ok {
if url, ok = urlMap["url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
} else {
if url, ok = data["image_url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
}
types := []string{"jpeg", "jpg", "png"}
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, fmt.Errorf("invalid image input")
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, fmt.Errorf("invalid message format")
}
message.Images = append(message.Images, img)
default:
return nil, fmt.Errorf("invalid message format")
}
}
messages = append(messages, message)
default:
return nil, fmt.Errorf("invalid message content type: %T", content)
}
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
}
options := make(map[string]interface{})
@@ -376,7 +156,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []any:
case []interface{}:
var stops []string
for _, s := range stop {
if str, ok := s.(string); ok {
@@ -419,96 +199,22 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
format = "json"
}
return &api.ChatRequest{
return api.ChatRequest{
Model: r.Model,
Messages: messages,
Format: format,
Options: options,
Stream: &r.Stream,
}, nil
}
}
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
options := make(map[string]any)
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []any:
var stops []string
for _, s := range stop {
if str, ok := s.(string); ok {
stops = append(stops, str)
} else {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
}
}
options["stop"] = stops
}
if r.MaxTokens != nil {
options["num_predict"] = *r.MaxTokens
}
if r.Temperature != nil {
options["temperature"] = *r.Temperature * 2.0
} else {
options["temperature"] = 1.0
}
if r.Seed != nil {
options["seed"] = *r.Seed
}
options["frequency_penalty"] = r.FrequencyPenalty * 2.0
options["presence_penalty"] = r.PresencePenalty * 2.0
if r.TopP != 0.0 {
options["top_p"] = r.TopP
} else {
options["top_p"] = 1.0
}
return api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
Options: options,
Stream: &r.Stream,
}, nil
}
type BaseWriter struct {
type writer struct {
stream bool
id string
gin.ResponseWriter
}
type ChatWriter struct {
stream bool
id string
BaseWriter
}
type CompleteWriter struct {
stream bool
id string
BaseWriter
}
type ListWriter struct {
BaseWriter
}
type RetrieveWriter struct {
BaseWriter
model string
}
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
func (w *writer) writeError(code int, data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
@@ -524,7 +230,7 @@ func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
return len(data), nil
}
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
func (w *writer) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
@@ -564,7 +270,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
return len(data), nil
}
func (w *ChatWriter) Write(data []byte) (int, error) {
func (w *writer) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
@@ -573,244 +279,7 @@ func (w *ChatWriter) Write(data []byte) (int, error) {
return w.writeResponse(data)
}
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
err := json.Unmarshal(data, &generateResponse)
if err != nil {
return 0, err
}
// completion chunk
if w.stream {
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if generateResponse.Done {
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func (w *ListWriter) writeResponse(data []byte) (int, error) {
var listResponse api.ListResponse
err := json.Unmarshal(data, &listResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
var showResponse api.ShowResponse
err := json.Unmarshal(data, &showResponse)
if err != nil {
return 0, err
}
// retrieve completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
func RetrieveMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
// response writer
w := &RetrieveWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: c.Param("model"),
}
c.Writer = w
c.Next()
}
}
func CompletionsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req CompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
genReq, err := fromCompleteRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
}
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w
c.Next()
}
}
func ChatMiddleware() gin.HandlerFunc {
func Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req ChatCompletionRequest
err := c.ShouldBindJSON(&req)
@@ -825,23 +294,17 @@ func ChatMiddleware() gin.HandlerFunc {
}
var b bytes.Buffer
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
w := &writer{
ResponseWriter: c.Writer,
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
}
c.Writer = w

View File

@@ -1,392 +0,0 @@
package openai
import (
"bytes"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/assert"
)
const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image
func TestMiddlewareRequests(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Handler func() gin.HandlerFunc
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *http.Request)
}
var capturedRequest *http.Request
captureRequestMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
capturedRequest = c.Request
c.Next()
}
}
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
Handler: CompletionsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var genReq api.GenerateRequest
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
}
if genReq.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
}
if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
},
},
{
Name: "chat handler with image content",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{
Role: "user", Content: []map[string]any{
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
},
},
},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var chatReq api.ChatRequest
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
}
if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
}
if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
}
},
},
{
Name: "embed handler single input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
}
if embedReq.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", embedReq.Input)
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
}
},
},
{
Name: "embed handler batch input",
Method: http.MethodPost,
Path: "/api/embed",
Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var embedReq api.EmbedRequest
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
t.Fatal(err)
}
input, ok := embedReq.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if embedReq.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(captureRequestMiddleware())
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, endpoint)
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest)
})
}
}
func TestMiddlewareResponses(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
TestPath string
Handler func() gin.HandlerFunc
Endpoint func(c *gin.Context)
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
}
testCases := []testCase{
{
Name: "completions handler error forwarding",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
t.Fatalf("error was not forwarded")
}
},
},
{
Name: "list handler",
Method: http.MethodGet,
Path: "/api/tags",
TestPath: "/api/tags",
Handler: ListMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{
{
Name: "Test Model",
},
},
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
if listResp.Object != "list" {
t.Fatalf("expected list, got %s", listResp.Object)
}
if len(listResp.Data) != 1 {
t.Fatalf("expected 1, got %d", len(listResp.Data))
}
if listResp.Data[0].Id != "Test Model" {
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
}
},
},
{
Name: "retrieve model",
Method: http.MethodGet,
Path: "/api/show/:model",
TestPath: "/api/show/test-model",
Handler: RetrieveMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
var retrieveResp Model
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
t.Fatal(err)
}
if retrieveResp.Object != "model" {
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
}
if retrieveResp.Id != "test-model" {
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, tc.Endpoint)
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, resp)
})
}
}

View File

@@ -124,7 +124,7 @@ func ParseFile(r io.Reader) (*File, error) {
case stateComment, stateNil:
// pass
case stateValue:
s, ok := unquote(strings.TrimSpace(b.String()))
s, ok := unquote(b.String())
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
@@ -158,7 +158,7 @@ func ParseFile(r io.Reader) (*File, error) {
case stateComment, stateNil:
// pass; nothing to flush
case stateValue:
s, ok := unquote(strings.TrimSpace(b.String()))
s, ok := unquote(b.String())
if !ok {
return nil, io.ErrUnexpectedEOF
}

View File

@@ -22,13 +22,7 @@ ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>"""
TEMPLATE template1
`
reader := strings.NewReader(input)
@@ -42,40 +36,7 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
{Name: "license", Args: "MIT"},
{Name: "param1", Args: "value1"},
{Name: "param2", Args: "value2"},
{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
}
assert.Equal(t, expectedCommands, modelfile.Commands)
}
func TestParseFileTrimSpace(t *testing.T) {
input := `
FROM " model 1"
ADAPTER adapter3
LICENSE "MIT "
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|> """
`
reader := strings.NewReader(input)
modelfile, err := ParseFile(reader)
require.NoError(t, err)
expectedCommands := []Command{
{Name: "model", Args: " model 1"},
{Name: "adapter", Args: "adapter3"},
{Name: "license", Args: "MIT "},
{Name: "param1", Args: "value1"},
{Name: "param2", Args: "value2"},
{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
{Name: "template", Args: "template1"},
}
assert.Equal(t, expectedCommands, modelfile.Commands)
@@ -87,26 +48,6 @@ func TestParseFileFrom(t *testing.T) {
expected []Command
err error
}{
{
"FROM \"FOO BAR \"",
[]Command{{Name: "model", Args: "FOO BAR "}},
nil,
},
{
"FROM \"FOO BAR\"\nPARAMETER param1 value1",
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
nil,
},
{
"FROM FOOO BAR ",
[]Command{{Name: "model", Args: "FOOO BAR"}},
nil,
},
{
"FROM /what/is/the path ",
[]Command{{Name: "model", Args: "/what/is/the path"}},
nil,
},
{
"FROM foo",
[]Command{{Name: "model", Args: "foo"}},
@@ -145,11 +86,6 @@ func TestParseFileFrom(t *testing.T) {
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
nil,
},
{
"PARAMETER what the \nFROM lemons make lemonade ",
[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
nil,
},
}
for _, c := range cases {
@@ -463,7 +399,7 @@ func TestParseFileParameters(t *testing.T) {
"mirostat_eta 1.0": {"mirostat_eta", "1.0"},
"penalize_newline true": {"penalize_newline", "true"},
"stop ### User:": {"stop", "### User:"},
"stop ### User: ": {"stop", "### User:"},
"stop ### User: ": {"stop", "### User: "},
"stop \"### User:\"": {"stop", "### User:"},
"stop \"### User: \"": {"stop", "### User: "},
"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},

View File

@@ -107,12 +107,9 @@ function gatherDependencies() {
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
# currently works for Win11 + MSVC 2019 + Cuda V11
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
}
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"

View File

@@ -6,21 +6,10 @@ set -ex
MACHINE=$(uname -m)
if grep -i "centos" /etc/system-release >/dev/null; then
# As of 7/1/2024 mirrorlist.centos.org has been taken offline, so adjust accordingly
sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
# Centos 7 derivatives have too old of a git version to run our generate script
# uninstall and ignore failures
yum remove -y git
yum -y install epel-release centos-release-scl
# The release packages reinstate the mirrors, undo that again
sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
yum -y install dnf
if [ "${MACHINE}" = "x86_64" ]; then
yum -y install https://repo.ius.io/ius-release-el7.rpm

View File

@@ -28,27 +28,11 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
var (
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
)
type Capability string
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
)
type registryOptions struct {
Insecure bool
Username string
@@ -64,59 +48,16 @@ type Model struct {
ParentModel string
AdapterPaths []string
ProjectorPaths []string
Template string
System string
License []string
Digest string
Options map[string]interface{}
Messages []Message
Template *template.Template
}
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(caps ...Capability) error {
var errs []error
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
f, err := os.Open(m.ModelPath)
if err != nil {
slog.Error("couldn't open model file", "error", err)
continue
}
defer f.Close()
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
ggml, _, err := llm.DecodeGGML(f, 0)
if err != nil {
slog.Error("couldn't decode ggml", "error", err)
continue
}
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
errs = append(errs, errCapabilityCompletion)
}
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errCapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
}
default:
slog.Error("unknown capability", "capability", cap)
return fmt.Errorf("unknown capability: %s", cap)
}
}
if err := errors.Join(errs...); err != nil {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
}
return nil
func (m *Model) IsEmbedding() bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
}
func (m *Model) String() string {
@@ -141,10 +82,10 @@ func (m *Model) String() string {
})
}
if m.Template != nil {
if m.Template != "" {
modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "template",
Args: m.Template.String(),
Args: m.Template,
})
}
@@ -194,6 +135,13 @@ type Message struct {
Content string `json:"content"`
}
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
}
type ConfigV2 struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
@@ -212,7 +160,7 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"`
}
func GetManifest(mp ModelPath) (*Manifest, string, error) {
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
return nil, "", err
@@ -222,7 +170,7 @@ func GetManifest(mp ModelPath) (*Manifest, string, error) {
return nil, "", err
}
var manifest *Manifest
var manifest *ManifestV2
bts, err := os.ReadFile(fp)
if err != nil {
@@ -250,7 +198,8 @@ func GetModel(name string) (*Model, error) {
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
Template: template.DefaultTemplate,
Template: "{{ .Prompt }}",
License: []string{},
}
filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -286,17 +235,13 @@ func GetModel(name string) (*Model, error) {
model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
model.Template = string(bts)
case "application/vnd.ollama.image.system":
bts, err := os.ReadFile(filename)
if err != nil {
@@ -304,6 +249,13 @@ func GetModel(name string) (*Model, error) {
}
model.System = string(bts)
case "application/vnd.ollama.image.prompt":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.Template = string(bts)
case "application/vnd.ollama.image.params":
params, err := os.Open(filename)
if err != nil {
@@ -870,7 +822,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *Manifest
var manifest *ManifestV2
var err error
var noprune string
@@ -977,7 +929,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header)
@@ -988,7 +940,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
}
defer resp.Body.Close()
var m *Manifest
var m *ManifestV2
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}

View File

@@ -14,10 +14,7 @@ import (
)
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
ManifestV2
filepath string
fi os.FileInfo
@@ -69,7 +66,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
p := filepath.Join(manifests, n.Filepath())
var m Manifest
var m ManifestV2
f, err := os.Open(p)
if err != nil {
return nil, err
@@ -86,11 +83,12 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err
}
m.filepath = p
m.fi = fi
m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
return &m, nil
return &Manifest{
ManifestV2: m,
filepath: p,
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil
}
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
@@ -110,7 +108,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
}
defer f.Close()
m := Manifest{
m := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,

View File

@@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) {
}
defer f.Close()
if err := json.NewEncoder(f).Encode(Manifest{}); err != nil {
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
t.Fatal(err)
}
}

View File

@@ -4,7 +4,6 @@ import (
"archive/zip"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -12,14 +11,12 @@ import (
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/templates"
"github.com/ollama/ollama/types/model"
)
@@ -94,11 +91,12 @@ func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse))
fn(api.ProgressResponse{Status: "unpacking model metadata"})
for _, f := range r.File {
if !filepath.IsLocal(f.Name) {
return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
n := filepath.Join(p, f.Name)
if !strings.HasPrefix(n, p) {
slog.Warn("skipped extracting file outside of context", "name", f.Name)
continue
}
n := filepath.Join(p, f.Name)
if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
return err
}
@@ -260,7 +258,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := template.Named(s); err != nil {
if t, err := templates.NamedTemplate(s); err != nil {
slog.Debug("template detection", "error", err)
} else {
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
@@ -293,87 +291,3 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil
}
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]map[string]any{
"ToolCalls": {
{
"Function": map[string]any{
"Name": "@@name@@",
"Arguments": "@@arguments@@",
},
},
},
}); err != nil {
return nil, false
}
var kv map[string]string
// execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v {
case "@@name@@":
name = k
case "@@arguments@@":
arguments = k
}
}
var objs []map[string]any
for offset := 0; offset < len(s); {
if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
return nil, false
} else {
// break when an object is decoded
break
}
}
var toolCalls []api.ToolCall
for _, kv := range objs {
var call api.ToolCall
for k, v := range kv {
switch k {
case name:
call.Function.Name = v.(string)
case arguments:
call.Function.Arguments = v.(map[string]any)
}
}
toolCalls = append(toolCalls, call)
}
return toolCalls, len(toolCalls) > 0
}

View File

@@ -3,19 +3,13 @@ package server
import (
"archive/zip"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
func createZipFile(t *testing.T, name string) *os.File {
@@ -45,31 +39,13 @@ func TestExtractFromZipFile(t *testing.T) {
cases := []struct {
name string
expect []string
err error
}{
{
name: "good",
expect: []string{"good"},
},
{
name: strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
expect: []string{filepath.Join("to", "good")},
},
{
name: strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
expect: []string{"good"},
},
{
name: strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
expect: []string{"good"},
},
{
name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
err: zip.ErrInsecurePath,
},
{
name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
err: zip.ErrInsecurePath,
name: filepath.Join("..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"),
},
}
@@ -79,7 +55,7 @@ func TestExtractFromZipFile(t *testing.T) {
defer f.Close()
tempDir := t.TempDir()
if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); err != nil {
t.Fatal(err)
}
@@ -114,123 +90,3 @@ func TestExtractFromZipFile(t *testing.T) {
})
}
}
type function struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}
return bytes.NewBuffer(bts)
}
func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
model string
output string
ok bool
}{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```", true},
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
}
var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
calls := []api.ToolCall{
{
Function: function{
Name: "get_current_weather",
Arguments: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: function{
Name: "get_current_weather",
Arguments: map[string]any{
"format": "celsius",
"location": "Toronto, Canada",
},
},
},
}
for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
if ok != tt.ok {
t.Fatalf("expected %t, got %t", tt.ok, ok)
}
if tt.ok {
if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
})
})
}
}

View File

@@ -103,9 +103,18 @@ func (mp ModelPath) GetShortTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}
// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
// The models directory is where Ollama stores its model files and manifests.
func modelsDir() (string, error) {
return envconfig.ModelsDir, nil
}
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
dir := envconfig.ModelsDir
dir, err := modelsDir()
if err != nil {
return "", err
}
return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
}
@@ -118,7 +127,10 @@ func (mp ModelPath) BaseURL() *url.URL {
}
func GetManifestPath() (string, error) {
dir := envconfig.ModelsDir
dir, err := modelsDir()
if err != nil {
return "", err
}
path := filepath.Join(dir, "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
@@ -129,7 +141,10 @@ func GetManifestPath() (string, error) {
}
func GetBlobsPath(digest string) (string, error) {
dir := envconfig.ModelsDir
dir, err := modelsDir()
if err != nil {
return "", err
}
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"

View File

@@ -1,74 +1,221 @@
package server
import (
"bytes"
"context"
"fmt"
"log/slog"
"strings"
"text/template"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
)
type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
// always include the last message
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
// isResponseNode checks if the node contains .Response
func isResponseNode(node *parse.ActionNode) bool {
for _, cmd := range node.Pipe.Cmds {
for _, arg := range cmd.Args {
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
if fieldNode.Ident[0] == "Response" {
return true
}
}
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
return "", nil, err
}
s, err := tokenize(ctx, b.String())
if err != nil {
return "", nil, err
}
c := len(s)
if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
c += 768 * len(m.Images)
}
}
if c > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
n = i
}
}
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
return "", nil, err
}
for _, m := range msgs[n:] {
for _, i := range m.Images {
images = append(images, llm.ImageData{
ID: len(images),
Data: i,
})
}
}
return b.String(), images, nil
return false
}
// formatTemplateForResponse formats the template AST to:
// 1. remove all nodes after the first .Response (if generate=true)
// 2. add a .Response node to the end if it doesn't exist
// TODO(jmorganca): this should recursively cut the template before the first .Response
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
var found bool
for i, node := range tmpl.Tree.Root.Nodes {
if actionNode, ok := node.(*parse.ActionNode); ok {
if isResponseNode(actionNode) {
found = true
if generate {
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
break
}
}
}
}
if !found {
// add the response node if it doesn't exist
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
}
}
// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
if err != nil {
return "", err
}
formatTemplateForResponse(parsed, generate)
vars := map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}
var sb strings.Builder
if err := parsed.Execute(&sb, vars); err != nil {
return "", err
}
return sb.String(), nil
}
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil {
return 0, err
}
tokens, err := encode(rendered)
if err != nil {
slog.Error("failed to encode prompt", "err", err)
return 0, err
}
return len(tokens), err
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
Response string
images []int
tokens int
}
var p prompt
// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt
for _, msg := range messages {
switch strings.ToLower(msg.Role) {
case "system":
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.System = msg.Content
case "user":
if p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
var sb strings.Builder
for range msg.Images {
fmt.Fprintf(&sb, "[img-%d] ", imgId)
p.images = append(p.images, imgId)
imgId += 1
}
sb.WriteString(msg.Content)
p.Prompt = sb.String()
case "assistant":
if p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.Response = msg.Content
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// add final prompt
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
}
// calculate token lengths for each prompt, estimating 768 tokens per images
for i, p := range prompts {
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
if err != nil {
return "", err
}
prompts[i].tokens = tokens + len(prompts[i].images)*768
}
// truncate images and prompts starting from the beginning of the list
// until either one prompt remains or the total tokens fits the context window
// TODO (jmorganca): this doesn't account for the context window room required for the response
for {
var required int
for _, p := range prompts {
required += p.tokens
}
required += 1 // for bos token
if required <= window {
slog.Debug("prompt now fits in context window", "required", required, "window", window)
break
}
prompt := &prompts[0]
if len(prompt.images) > 1 {
img := prompt.images[0]
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
prompt.images = prompt.images[1:]
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
prompt.tokens -= 768
continue
}
if len(prompts) > 1 {
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
system := prompt.System
prompts = prompts[1:]
if system != "" && prompts[0].System == "" {
prompts[0].System = system
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
if err != nil {
return "", err
}
prompts[0].tokens = tokens + len(prompts[0].images)*768
}
continue
}
// stop truncating if there's only one prompt left
break
}
var sb strings.Builder
for i, p := range prompts {
// last prompt should leave the response unrendered (for completion)
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
if err != nil {
return "", err
}
sb.WriteString(rendered)
}
return sb.String(), nil
}

View File

@@ -1,209 +1,204 @@
package server
import (
"bytes"
"context"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
func TestChatPrompt(t *testing.T) {
type expect struct {
prompt string
images [][]byte
}
cases := []struct {
name string
limit int
msgs []api.Message
expect
func TestPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
prompt string
response string
generate bool
want string
}{
{
name: "messages",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
name: "simple prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "truncate messages",
limit: 1,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "A test. And a thumping good one at that, I'd wager. ",
},
name: "implicit response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
},
{
name: "truncate messages with image",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
},
},
name: "response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "truncate messages with images",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
},
name: "cut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
generate: true,
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
},
{
name: "messages with images",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
},
{
name: "message with image tag",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
},
{
name: "messages with interleaved images",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
},
{
name: "truncate message with interleaved images",
limit: 1024,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
},
},
{
name: "message with system prompt",
limit: 2048,
msgs: []api.Message{
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "out of order system",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
name: "nocut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
},
}
tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil {
t.Fatal(err)
t.Errorf("error = %v", err)
}
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if len(images) != len(tt.images) {
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
}
for i := range images {
if images[i].ID != i {
t.Errorf("expected ID %d, got %d", i, images[i].ID)
}
if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i])
}
if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
}
func TestChatPrompt(t *testing.T) {
tests := []struct {
name string
template string
messages []api.Message
window int
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] Hello [/INST]",
},
{
name: "with system message",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
},
{
name: "with response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
},
{
name: "with implicit response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
},
{
name: "with conversation",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "What are the potion ingredients?"},
{Role: "assistant", Content: "sugar"},
{Role: "user", Content: "Anything else?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
},
{
name: "with truncation",
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
{Role: "user", Content: "Why is the sky blue?"},
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
},
window: 10,
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
},
{
name: "images",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
},
window: 1024,
want: "You are a Wizard. [img-0] Hello",
},
{
name: "images truncated",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
},
window: 1024,
want: "You are a Wizard. [img-0] [img-1] Hello",
},
{
name: "empty list",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{},
window: 1024,
want: "",
},
{
name: "empty prompt",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "user", Content: ""},
},
window: 1024,
want: "",
},
}
encode := func(s string) ([]int, error) {
words := strings.Fields(s)
return make([]int, len(words)), nil
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
if err != nil {
t.Errorf("error = %v", err)
}
if got != tc.want {
t.Errorf("got: %q, want: %q", got, tc.want)
}
})
}

View File

@@ -1,13 +1,13 @@
package server
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"net"
@@ -17,6 +17,7 @@ import (
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"syscall"
"time"
@@ -30,7 +31,6 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -55,7 +55,7 @@ func init() {
gin.SetMode(mode)
}
var errRequired = errors.New("is required")
var defaultSessionDuration = 5 * time.Minute
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
@@ -70,225 +70,277 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil
}
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
model, err := GetModel(name)
if err != nil {
return nil, nil, nil, err
}
if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
opts, err := modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err = <-errCh:
return nil, nil, nil, err
}
return runner.llama, model, &opts, nil
func isSupportedImageType(image []byte) bool {
contentType := http.DetectContentType(image)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
return slices.Contains(allowedTypes, contentType)
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Format != "" && req.Format != "json" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
caps := []Capability{CapabilityCompletion}
if req.Suffix != "" {
caps = append(caps, CapabilityInsert)
for _, img := range req.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
checkpointLoaded := time.Now()
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
return
}
if req.Prompt == "" {
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
DoneReason: "load",
})
return
}
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
checkpointLoaded := time.Now()
prompt := req.Prompt
if !req.Raw {
tmpl := m.Template
if req.Template != "" {
tmpl, err = template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template == "" {
req.Template = model.Template
}
var b bytes.Buffer
if req.Context != nil {
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
b.WriteString(s)
if req.System == "" {
req.System = model.System
}
var values template.Values
if req.Suffix != "" {
values.Prompt = prompt
values.Suffix = req.Suffix
} else {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
var sb strings.Builder
for i := range req.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
}
if err := tmpl.Execute(&b, values); err != nil {
sb.WriteString(req.Prompt)
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt = b.String()
sb.Reset()
if req.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(prev)
}
sb.WriteString(p)
prompt = sb.String()
}
slog.Debug("generate request", "prompt", prompt, "images", images)
slog.Debug("generate handler", "prompt", prompt)
ch := make(chan any)
var generated strings.Builder
go func() {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
fn := func(r llm.CompletionResponse) {
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
DoneReason: cr.DoneReason,
Done: r.Done,
Response: r.Content,
DoneReason: r.DoneReason,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration,
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
}
if cr.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Context = append(req.Context, tokens...)
resp.Context = append(req.Context, tokens...)
}
}
ch <- res
}); err != nil {
ch <- resp
}
var images []llm.ImageData
for i := range req.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
}
// Start prediction
req := llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
var r api.GenerateResponse
// Accumulate responses into the final response
var final api.GenerateResponse
var sb strings.Builder
for rr := range ch {
switch t := rr.(type) {
for resp := range ch {
switch r := resp.(type) {
case api.GenerateResponse:
sb.WriteString(t.Response)
r = t
sb.WriteString(r.Response)
final = r
case gin.H:
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
r.Response = sb.String()
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
r.ToolCalls = toolCalls
r.Response = ""
}
c.JSON(http.StatusOK, r)
final.Response = sb.String()
c.JSON(http.StatusOK, final)
return
}
streamResponse(c, ch)
}
func (s *Server) EmbedHandler(c *gin.Context) {
var req api.EmbedRequest
func getDefaultSessionDuration() time.Duration {
if envconfig.KeepAlive != "" {
v, err := strconv.Atoi(envconfig.KeepAlive)
if err != nil {
d, err := time.ParseDuration(envconfig.KeepAlive)
if err != nil {
return defaultSessionDuration
}
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
d := time.Duration(v) * time.Second
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
return defaultSessionDuration
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
@@ -299,122 +351,41 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
var input []string
switch i := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
}
case []any:
for _, v := range i {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
model, err := GetModel(req.Model)
if err != nil {
handleScheduleError(c, req.Model, err)
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
kvData, err := getKVData(m.ModelPath, false)
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input)
if err != nil {
slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
for i, e := range embeddings {
embeddings[i] = normalize(e)
}
resp := api.EmbedResponse{
Model: req.Model,
Embeddings: embeddings,
}
c.JSON(http.StatusOK, resp)
}
func normalize(vec []float32) []float32 {
var sum float32
for _, v := range vec {
sum += v * v
}
norm := float32(0.0)
if sum > 0 {
norm = float32(1.0 / math.Sqrt(float64(sum)))
}
for i := range vec {
vec[i] *= norm
}
return vec
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
@@ -424,20 +395,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
embedding := make([]float64, len(embeddings[0]))
for i, v := range embeddings[0] {
embedding[i] = float64(v)
}
resp := api.EmbeddingResponse{
Embedding: embedding,
}
@@ -715,9 +679,13 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m.System = req.System
}
msgs := make([]api.Message, len(m.Messages))
for i, msg := range m.Messages {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
if req.Template != "" {
m.Template = req.Template
}
msgs := make([]api.Message, 0)
for _, msg := range m.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
}
n := model.ParseName(req.Model)
@@ -733,7 +701,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Template: m.Template,
Details: modelDetails,
Messages: msgs,
ModifiedAt: manifest.fi.ModTime(),
@@ -1060,7 +1028,6 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler)
@@ -1072,11 +1039,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r.GET("/api/ps", s.ProcessHandler)
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
@@ -1282,67 +1245,139 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
if err != nil {
return "", err
}
return prompt, nil
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
caps := []Capability{CapabilityCompletion}
if req.Tools != nil {
caps = append(caps, CapabilityTools)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
}
checkpointLoaded := time.Now()
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
Done: true,
DoneReason: "load",
})
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if req.Messages[0].Role != "system" && m.System != "" {
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
return
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slog.Debug("chat request", "images", len(images), "prompt", prompt)
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
checkpointLoaded := time.Now()
// if the first message is not a system message, then add the model's default system message
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
req.Messages = append([]api.Message{
{
Role: "system",
Content: model.System,
},
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
if len(req.Messages) == 0 || prompt == "" {
resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
DoneReason: "load",
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
return
}
// only send images that are in the prompt
var i int
var images []llm.ImageData
for _, m := range req.Messages {
for _, img := range m.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
images = append(images, llm.ImageData{Data: img, ID: i})
}
i += 1
}
}
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
ch := make(chan any)
go func() {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
fn := func(r llm.CompletionResponse) {
resp := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1357,62 +1392,62 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if r.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- res
}); err != nil {
ch <- resp
}
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse
// Accumulate responses into the final response
var final api.ChatResponse
var sb strings.Builder
for rr := range ch {
switch t := rr.(type) {
for resp := range ch {
switch r := resp.(type) {
case api.ChatResponse:
sb.WriteString(t.Message.Content)
resp = t
sb.WriteString(r.Message.Content)
final = r
case gin.H:
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
resp.Message.Content = sb.String()
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
c.JSON(http.StatusOK, resp)
final.Message = api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, final)
return
}
streamResponse(c, ch)
}
func handleScheduleError(c *gin.Context, name string, err error) {
switch {
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled):
func handleErrorResponse(c *gin.Context, err error) {
if errors.Is(err, context.Canceled) {
c.JSON(499, gin.H{"error": "request canceled"})
case errors.Is(err, ErrMaxQueue):
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
case errors.Is(err, os.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if errors.Is(err, ErrMaxQueue) {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}

View File

@@ -85,8 +85,6 @@ func checkFileExists(t *testing.T, p string, expect []string) {
}
func TestCreateFromBin(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -113,8 +111,6 @@ func TestCreateFromBin(t *testing.T) {
}
func TestCreateFromModel(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -156,8 +152,6 @@ func TestCreateFromModel(t *testing.T) {
}
func TestCreateRemovesLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -205,8 +199,6 @@ func TestCreateRemovesLayers(t *testing.T) {
}
func TestCreateUnsetsSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -263,8 +255,6 @@ func TestCreateUnsetsSystem(t *testing.T) {
}
func TestCreateMergeParameters(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -368,8 +358,6 @@ func TestCreateMergeParameters(t *testing.T) {
}
func TestCreateReplacesMessages(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -446,8 +434,6 @@ func TestCreateReplacesMessages(t *testing.T) {
}
func TestCreateTemplateSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -494,8 +480,6 @@ func TestCreateTemplateSystem(t *testing.T) {
}
func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -542,8 +526,6 @@ func TestCreateLicenses(t *testing.T) {
}
func TestCreateDetectTemplate(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -563,9 +545,9 @@ func TestCreateDetectTemplate(t *testing.T) {
}
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
})
})

View File

@@ -8,15 +8,12 @@ import (
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
func TestDelete(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig()
@@ -80,8 +77,6 @@ func TestDelete(t *testing.T) {
}
func TestDeleteDuplicateLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server

View File

@@ -1,712 +0,0 @@
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
)
type mockRunner struct {
llm.LlamaServer
// CompletionRequest is only valid until the next call to Completion
llm.CompletionRequest
llm.CompletionResponse
}
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
m.CompletionRequest = r
fn(m.CompletionResponse)
return nil
}
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
return
}
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return mock, nil
}
}
func TestGenerateChat(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add 10ms delay to simulate loading
time.Sleep(10 * time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities chat", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.ChatResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if diff := cmp.Diff(actual.Message, api.Message{
Role: "assistant",
Content: content,
}); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("messages", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("messages with model system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("messages with system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("messages with interleaved system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "assistant", Content: "I can help you with that."},
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Help me write tests."},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
}
func TestGenerate(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities generate", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.GenerateResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.GenerateResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if actual.Response != content {
t.Errorf("expected response %s, got %s", content, actual.Response)
}
if actual.Context == nil {
t.Errorf("expected context not nil")
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("prompt", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with model system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("prompt with system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
System: "You can perform magic tricks.",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("prompt with template", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
System: "You can perform magic tricks.",
Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-suffix",
Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("prompt without suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("raw", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
Raw: true,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}

View File

@@ -7,14 +7,11 @@ import (
"slices"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
func TestList(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("OLLAMA_MODELS", t.TempDir())
envconfig.LoadConfig()

View File

@@ -7,7 +7,6 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"os"
@@ -21,7 +20,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -107,24 +105,6 @@ func Test_Routes(t *testing.T) {
assert.Empty(t, len(modelList.Models))
},
},
{
Name: "openai empty list",
Method: http.MethodGet,
Path: "/v1/models",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var modelList openai.ListCompletion
err = json.Unmarshal(body, &modelList)
require.NoError(t, err)
assert.Equal(t, "list", modelList.Object)
assert.Empty(t, modelList.Data)
},
},
{
Name: "Tags Handler (yes tags)",
Method: http.MethodGet,
@@ -148,25 +128,6 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
},
},
{
Name: "openai list models with tags",
Method: http.MethodGet,
Path: "/v1/models",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var modelList openai.ListCompletion
err = json.Unmarshal(body, &modelList)
require.NoError(t, err)
assert.Len(t, modelList.Data, 1)
assert.Equal(t, "test-model:latest", modelList.Data[0].Id)
assert.Equal(t, "library", modelList.Data[0].OwnedBy)
},
},
{
Name: "Create Model Handler",
Method: http.MethodPost,
@@ -255,95 +216,6 @@ func Test_Routes(t *testing.T) {
assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
},
},
{
Name: "openai retrieve model handler",
Method: http.MethodGet,
Path: "/v1/models/show-model",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var retrieveResp api.RetrieveModelResponse
err = json.Unmarshal(body, &retrieveResp)
require.NoError(t, err)
assert.Equal(t, "show-model", retrieveResp.Id)
assert.Equal(t, "library", retrieveResp.OwnedBy)
},
},
{
Name: "Embed Handler Empty Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: "",
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
var embedResp api.EmbedResponse
err = json.Unmarshal(body, &embedResp)
if err != nil {
t.Fatal(err)
}
if embedResp.Model != "t-bone" {
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
}
if embedResp.Embeddings == nil {
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
}
if len(embedResp.Embeddings) != 0 {
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
}
},
},
{
Name: "Embed Handler Invalid Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: 2,
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
_, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
}
},
},
}
t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -492,38 +364,3 @@ func TestShow(t *testing.T) {
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
}
}
func TestNormalize(t *testing.T) {
type testCase struct {
input []float32
}
testCases := []testCase{
{input: []float32{1}},
{input: []float32{0, 1, 2, 3}},
{input: []float32{0.1, 0.2, 0.3}},
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
{input: []float32{0, 0, 0}},
}
isNormalized := func(vec []float32) (res bool) {
sum := 0.0
for _, v := range vec {
sum += float64(v * v)
}
if math.Abs(sum-1) > 1e-6 {
return sum == 0
} else {
return true
}
}
for _, tc := range testCases {
t.Run("", func(t *testing.T) {
normalized := normalize(tc.input)
if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
}
})
}
}

View File

@@ -24,7 +24,7 @@ type LlmRequest struct {
model *Model
opts api.Options
origNumCtx int // Track the initial ctx request
sessionDuration *api.Duration
sessionDuration time.Duration
successCh chan *runnerRef
errCh chan error
schedAttempts uint
@@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
@@ -133,6 +133,10 @@ func (s *Scheduler) processPending(ctx context.Context) {
numParallel = 1
slog.Warn("multimodal models don't support parallel requests yet")
}
// Keep NumCtx and numParallel in sync
if numParallel > 1 {
pending.opts.NumCtx = pending.origNumCtx * numParallel
}
for {
var runnerToExpire *runnerRef
@@ -193,10 +197,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
// simplifying assumption of defaultParallel when in CPU mode
if numParallel <= 0 {
numParallel = defaultParallel
pending.opts.NumCtx = pending.origNumCtx * numParallel
}
pending.opts.NumCtx = pending.origNumCtx * numParallel
if loadedCount == 0 {
slog.Debug("cpu mode with first model, loading")
s.loadFn(pending, ggml, gpus, numParallel)
@@ -386,9 +389,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
runner.expireTimer.Stop()
runner.expireTimer = nil
}
if pending.sessionDuration != nil {
runner.sessionDuration = pending.sessionDuration.Duration
}
runner.sessionDuration = pending.sessionDuration
pending.successCh <- runner
go func() {
<-pending.ctx.Done()
@@ -401,10 +402,6 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
if numParallel < 1 {
numParallel = 1
}
sessionDuration := envconfig.KeepAlive
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
@@ -422,7 +419,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
modelPath: req.model.ModelPath,
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
sessionDuration: req.sessionDuration,
gpus: gpus,
estimatedVRAM: llama.EstimatedVRAM(),
estimatedTotal: llama.EstimatedTotal(),

View File

@@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: &api.Duration{Duration: 2 * time.Second},
sessionDuration: 2,
}
// Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
ctx: scenario.ctx,
model: model,
opts: api.DefaultOptions(),
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
sessionDuration: 5 * time.Millisecond,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1a.req.sessionDuration = 5 * time.Millisecond
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
scenario1b.req.sessionDuration = 0
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario2a.req.sessionDuration = 5 * time.Millisecond
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -199,8 +199,6 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1a.req.errCh)
case err := <-scenario1a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -214,8 +212,6 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1b.req.errCh)
case err := <-scenario1b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -234,8 +230,6 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario2a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario2a.req.errCh)
case err := <-scenario2a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -252,8 +246,6 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario3a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3a.req.errCh)
case err := <-scenario3a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -270,8 +262,6 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario3b.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3b.req.errCh)
case err := <-scenario3b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -288,8 +278,6 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario3c.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3c.req.errCh)
case err := <-scenario3c.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -330,11 +318,11 @@ func TestGetRunner(t *testing.T) {
defer done()
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
scenario1b.req.sessionDuration = 0
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
scenario1c.req.sessionDuration = 0
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
@@ -414,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
case <-ctx.Done():
t.Fatal("timeout")
}
time.Sleep(scenario1a.req.sessionDuration.Duration)
time.Sleep(scenario1a.req.sessionDuration)
scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
@@ -435,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
ctx: ctx,
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
sessionDuration: &api.Duration{Duration: 2},
sessionDuration: 2,
}
finished := make(chan *LlmRequest)
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -626,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
dctx, done2 := context.WithCancel(ctx)
done2()
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
scenario1a.req.sessionDuration = 0
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
@@ -642,8 +630,8 @@ type mockLlm struct {
pingResp error
waitResp error
completionResp error
embedResp [][]float32
embedRespErr error
embeddingResp []float64
embeddingRespErr error
tokenizeResp []int
tokenizeRespErr error
detokenizeResp string
@@ -660,8 +648,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
return s.embedResp, s.embedRespErr
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr

View File

@@ -1,67 +0,0 @@
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
{{- if .Tools }}# Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
{{ if .System }}# User Preamble
{{ .System }}
{{- end }}
## Available Tools
Here is a list of tools that you have available to you:
{{- range .Tools }}
```python
def {{ .Function.Name }}(
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
"""{{ .Function.Description }}
{{- if .Function.Parameters.Properties }}
Args:
{{- range $name, $property := .Function.Parameters.Properties }}
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
{{- end }}
{{- end }}
"""
pass
```
{{- end }}
{{- else if .System }}{{ .System }}
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- range .Messages }}
{{- if eq .Role "system" }}
{{- continue }}
{{- end }}<|START_OF_TURN_TOKEN|>
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}
Action: ```json
[
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
"parameters": {{ json .Function.Arguments }}
}
{{- end }}
]```
{{ continue }}
{{ end }}
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
{{ .Content }}</results>
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -1,39 +0,0 @@
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
# User Preamble
You are a knowledgable assistant. You can answer questions and perform tasks.
## Available Tools
Here is a list of tools that you have available to you:
```python
def get_current_weather(format: string, location: string, ) -> List[Dict]:
"""Get the current weather
Args:
format (string): The temperature unit to use. Infer this from the users location.
location (string): The city and state, e.g. San Francisco, CA
"""
pass
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
Action: ```json
[
{
"tool_name": "get_current_weather",
"parameters": {"format":"celsius","location":"Paris, France"}
}
]```
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -1,31 +0,0 @@
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{- if .System }}
{{ .System }}
{{- end }}
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
{{- if .Tools }}
{{ json .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,17 +0,0 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks.
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,39 +0,0 @@
[
{
"role": "system",
"content": "You are a knowledgable assistant. You can answer questions and perform tasks."
},
{
"role": "user",
"content": "What's the weather like today in Paris?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": {
"location": "Paris, France",
"format": "celsius"
}
}
}
]
},
{
"role": "tool",
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"content": "22"
},
{
"role": "assistant",
"content": "The current temperature in Paris, France is 22 degrees Celsius."
},
{
"role": "user",
"content": "What's the weather like today in San Francisco and Toronto?"
}
]

View File

@@ -1,15 +0,0 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}

View File

@@ -1,3 +0,0 @@
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks.
What's the weather like today in San Francisco and Toronto?[/INST]

View File

@@ -1,30 +0,0 @@
[
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": [
"location",
"format"
]
}
}
}
]

View File

@@ -1,10 +0,0 @@
{{ if .System }}Source: system
{{ .System }} <step> {{ end }}Source: user
{{ .Prompt }} <step> Source: assistant
{{- if not .Response }}
Destination: user
{{- end }}
{{ .Response }} <step>

View File

@@ -1,5 +0,0 @@
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User:
{{ .Prompt }}
{{ end }}Falcon:
{{ .Response }}

View File

@@ -1,5 +0,0 @@
<start_of_turn>user
{{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>

View File

@@ -1,6 +0,0 @@
[INST] <<SYS>>
{{- if .System }}
{{ .System }}
{{ end }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}</s><s>

View File

@@ -1,3 +0,0 @@
[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}</s>

View File

@@ -1 +0,0 @@
{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>

View File

@@ -1,429 +0,0 @@
package template
import (
"bytes"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"slices"
"strings"
"sync"
"text/template"
"text/template/parse"
"github.com/agnivade/levenshtein"
"github.com/ollama/ollama/api"
"golang.org/x/exp/maps"
)
//go:embed index.json
var indexBytes []byte
//go:embed *.gotmpl
var templatesFS embed.FS
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
var templates []*named
if err := json.Unmarshal(indexBytes, &templates); err != nil {
return nil, err
}
for _, t := range templates {
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
if err != nil {
return nil, err
}
// normalize line endings
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
}
return templates, nil
})
type named struct {
Name string `json:"name"`
Template string `json:"template"`
Bytes []byte
}
func (t named) Reader() io.Reader {
return bytes.NewReader(t.Bytes)
}
func Named(s string) (*named, error) {
templates, err := templatesOnce()
if err != nil {
return nil, err
}
var template *named
score := math.MaxInt
for _, t := range templates {
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
score = s
template = t
}
}
if score < 100 {
return template, nil
}
return nil, errors.New("no matching template found")
}
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
type Template struct {
*template.Template
raw string
}
// response is a template node that can be added to templates that don't already have one
var response = parse.ActionNode{
NodeType: parse.NodeAction,
Pipe: &parse.PipeNode{
NodeType: parse.NodePipe,
Cmds: []*parse.CommandNode{
{
NodeType: parse.NodeCommand,
Args: []parse.Node{
&parse.FieldNode{
NodeType: parse.NodeField,
Ident: []string{"Response"},
},
},
},
},
},
}
var funcs = template.FuncMap{
"json": func(v any) string {
b, _ := json.Marshal(v)
return string(b)
},
}
func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tmpl, err := tmpl.Parse(s)
if err != nil {
return nil, err
}
t := Template{Template: tmpl, raw: s}
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
// touch up the template and append {{ .Response }}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
}
return &t, nil
}
func (t *Template) String() string {
return t.raw
}
func (t *Template) Vars() []string {
var vars []string
for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes {
vars = append(vars, Identifiers(n)...)
}
}
set := make(map[string]struct{})
for _, n := range vars {
set[strings.ToLower(n)] = struct{}{}
}
vars = maps.Keys(set)
slices.Sort(vars)
return vars
}
type Values struct {
Messages []api.Message
Tools []api.Tool
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
}
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
var walk func(parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return n
}
switch t := n.(type) {
case *parse.ListNode:
for _, c := range t.Nodes {
if n := walk(c); n != nil {
return n
}
}
case *parse.BranchNode:
for _, n := range []*parse.ListNode{t.List, t.ElseList} {
if n != nil {
if n := walk(n); n != nil {
return n
}
}
}
case *parse.IfNode:
return walk(&t.BranchNode)
case *parse.WithNode:
return walk(&t.BranchNode)
case *parse.RangeNode:
return walk(&t.BranchNode)
}
return nil
}
if n := walk(t.Tree.Root); n != nil {
return (&template.Template{
Tree: &parse.Tree{
Root: &parse.ListNode{
Nodes: []parse.Node{n},
},
},
}).Funcs(funcs)
}
return nil
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Tools": v.Tools,
})
}
system = ""
var b bytes.Buffer
var prompt, response string
for _, m := range messages {
execute := func() error {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}); err != nil {
return err
}
system = ""
prompt = ""
response = ""
return nil
}
switch m.Role {
case "system":
if prompt != "" || response != "" {
if err := execute(); err != nil {
return err
}
}
system = m.Content
case "user":
if response != "" {
if err := execute(); err != nil {
return err
}
}
prompt = m.Content
case "assistant":
response = m.Content
}
}
var cut bool
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
cut = true
}
return cut
})
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
}); err != nil {
return err
}
_, err := io.Copy(w, &b)
return err
}
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also collects and returns all system messages.
// collate mutates message content adding image tags ([img-%d]) as needed
func collate(msgs []api.Message) (string, []*api.Message) {
var n int
var system []string
var collated []*api.Message
for i := range msgs {
msg := msgs[i]
for range msg.Images {
imageTag := fmt.Sprintf("[img-%d]", n)
if !strings.Contains(msg.Content, "[img]") {
msg.Content = strings.TrimSpace("[img] " + msg.Content)
}
msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
n++
}
if msg.Role == "system" {
system = append(system, msg.Content)
}
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
collated[len(collated)-1].Content += "\n\n" + msg.Content
} else {
collated = append(collated, &msg)
}
}
return strings.Join(system, "\n\n"), collated
}
// Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) []string {
switch n := n.(type) {
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, Identifiers(n)...)
}
return names
case *parse.TemplateNode:
return Identifiers(n.Pipe)
case *parse.ActionNode:
return Identifiers(n.Pipe)
case *parse.BranchNode:
names := Identifiers(n.Pipe)
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
if n != nil {
names = append(names, Identifiers(n)...)
}
}
return names
case *parse.IfNode:
return Identifiers(&n.BranchNode)
case *parse.RangeNode:
return Identifiers(&n.BranchNode)
case *parse.WithNode:
return Identifiers(&n.BranchNode)
case *parse.PipeNode:
var names []string
for _, c := range n.Cmds {
for _, a := range c.Args {
names = append(names, Identifiers(a)...)
}
}
return names
case *parse.FieldNode:
return n.Ident
case *parse.VariableNode:
return n.Ident
}
return nil
}
// deleteNode walks the node list and deletes nodes that match the predicate
// this is currently to remove the {{ .Response }} node from templates
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
var walk func(n parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return nil
}
switch t := n.(type) {
case *parse.ListNode:
var nodes []parse.Node
for _, c := range t.Nodes {
if n := walk(c); n != nil {
nodes = append(nodes, n)
}
}
t.Nodes = nodes
return t
case *parse.IfNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.WithNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.RangeNode:
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
case *parse.BranchNode:
t.List = walk(t.List).(*parse.ListNode)
if t.ElseList != nil {
t.ElseList = walk(t.ElseList).(*parse.ListNode)
}
case *parse.ActionNode:
n := walk(t.Pipe)
if n == nil {
return nil
}
t.Pipe = n.(*parse.PipeNode)
case *parse.PipeNode:
var commands []*parse.CommandNode
for _, c := range t.Cmds {
var args []parse.Node
for _, a := range c.Args {
if n := walk(a); n != nil {
args = append(args, n)
}
}
if len(args) == 0 {
return nil
}
c.Args = args
commands = append(commands, c)
}
if len(commands) == 0 {
return nil
}
t.Cmds = commands
}
return n
}
return walk(n)
}

View File

@@ -1,396 +0,0 @@
package template
import (
"bufio"
"bytes"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
func TestNamed(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
var ss map[string]string
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
t.Fatal(err)
}
for k, v := range ss {
t.Run(k, func(t *testing.T) {
kv := llm.KV{"tokenizer.chat_template": v}
s := kv.ChatTemplate()
r, err := Named(s)
if err != nil {
t.Fatal(err)
}
if r.Name != k {
t.Errorf("expected %q, got %q", k, r.Name)
}
var b bytes.Buffer
if _, err := io.Copy(&b, r.Reader()); err != nil {
t.Fatal(err)
}
tmpl, err := Parse(b.String())
if err != nil {
t.Fatal(err)
}
if tmpl.Tree.Root.String() == "" {
t.Errorf("empty %s template", k)
}
})
}
}
}
func TestTemplate(t *testing.T) {
cases := make(map[string][]api.Message)
for _, mm := range [][]api.Message{
{
{Role: "user", Content: "Hello, how are you?"},
},
{
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
{Role: "user", Content: "I'd like to show off how chat templating works!"},
},
{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
{Role: "user", Content: "I'd like to show off how chat templating works!"},
},
} {
var roles []string
for _, m := range mm {
roles = append(roles, m.Role)
}
cases[strings.Join(roles, "-")] = mm
}
matches, err := filepath.Glob("*.gotmpl")
if err != nil {
t.Fatal(err)
}
for _, match := range matches {
t.Run(match, func(t *testing.T) {
bts, err := os.ReadFile(match)
if err != nil {
t.Fatal(err)
}
tmpl, err := Parse(string(bts))
if err != nil {
t.Fatal(err)
}
for n, tt := range cases {
var actual bytes.Buffer
t.Run(n, func(t *testing.T) {
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
t.Fatal(err)
}
expect, err := os.ReadFile(filepath.Join("testdata", match, n))
if err != nil {
t.Fatal(err)
}
bts := actual.Bytes()
if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
t.Log("removing trailing space from output")
bts = bts[:len(bts)-1]
}
if diff := cmp.Diff(bts, expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("legacy", func(t *testing.T) {
t.Skip("legacy outputs are currently default outputs")
var legacy bytes.Buffer
if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
t.Fatal(err)
}
legacyBytes := legacy.Bytes()
if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
t.Log("removing trailing space from legacy output")
legacyBytes = legacyBytes[:len(legacyBytes)-1]
} else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
t.Skip("legacy outputs cannot be compared to messages outputs")
}
if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
})
}
}
func TestParse(t *testing.T) {
cases := []struct {
template string
vars []string
}{
{"{{ .Prompt }}", []string{"prompt", "response"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{`{{- range .Messages }}
{{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else -}}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
tmpl, err := Parse(tt.template)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestExecuteWithMessages(t *testing.T) {
type template struct {
name string
template string
}
cases := []struct {
name string
templates []template
values Values
expected string
}{
{
"mistral",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"mistral system",
[]template{
{"no response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] `},
{"response", `[INST] {{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `[INST] {{ if .System }}{{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`[INST] You are a helpful assistant!
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"chatml",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
`},
{"messages", `
{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
`},
},
Values{
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant!"},
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
},
},
`<|im_start|>system
You are a helpful assistant!<|im_end|>
<|im_start|>user
Hello friend!<|im_end|>
<|im_start|>assistant
Hello human!<|im_end|>
<|im_start|>user
What is your name?<|im_end|>
<|im_start|>assistant
`,
},
{
"moondream",
[]template{
// this does not have a "no response" test because it's impossible to render the same output
{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
{{ end }}Answer: {{ .Response }}
`},
{"messages", `
{{- range .Messages }}
{{- if eq .Role "user" }}Question: {{ .Content }}
{{ else if eq .Role "assistant" }}Answer: {{ .Content }}
{{ end }}
{{- end }}Answer: `},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "What's in this image?", Images: []api.ImageData{[]byte("")}},
{Role: "assistant", Content: "It's a hot dog."},
{Role: "user", Content: "What's in _this_ image?"},
{Role: "user", Images: []api.ImageData{[]byte("")}},
{Role: "user", Content: "Is it a hot dog?"},
},
},
`Question: [img-0] What's in this image?
Answer: It's a hot dog.
Question: What's in _this_ image?
[img-1]
Is it a hot dog?
Answer: `,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
for _, ttt := range tt.templates {
t.Run(ttt.name, func(t *testing.T) {
tmpl, err := Parse(ttt.template)
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, tt.values); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
})
}
}
func TestExecuteWithSuffix(t *testing.T) {
tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`)
if err != nil {
t.Fatal(err)
}
cases := []struct {
name string
values Values
expect string
}{
{
"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
},
{
"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
var b bytes.Buffer
if err := tmpl.Execute(&b, tt.values); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -1 +0,0 @@
<start_system>You are a helpful assistant.<end_message><start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

View File

@@ -1 +0,0 @@
<start_user>Hello, how are you?<end_message><start_assistant>

View File

@@ -1 +0,0 @@
<start_user>Hello, how are you?<end_message><start_assistant>I'm doing great. How can I help you today?<end_message><start_user>I'd like to show off how chat templating works!<end_message><start_assistant>

View File

@@ -1,12 +0,0 @@
You are a helpful assistant.
### Instruction:
Hello, how are you?
### Response:
I'm doing great. How can I help you today?
### Instruction:
I'd like to show off how chat templating works!
### Response:

View File

@@ -1,4 +0,0 @@
### Instruction:
Hello, how are you?
### Response:

View File

@@ -1,10 +0,0 @@
### Instruction:
Hello, how are you?
### Response:
I'm doing great. How can I help you today?
### Instruction:
I'd like to show off how chat templating works!
### Response:

View File

@@ -1,9 +0,0 @@
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<|im_start|>user
I'd like to show off how chat templating works!<|im_end|>
<|im_start|>assistant

View File

@@ -1,3 +0,0 @@
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant

View File

@@ -1,7 +0,0 @@
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you today?<|im_end|>
<|im_start|>user
I'd like to show off how chat templating works!<|im_end|>
<|im_start|>assistant

Some files were not shown because too many files have changed in this diff Show More