Merge branch 'main' into royh-batchembed
This commit is contained in:
commit
a5f23d766e
@ -70,12 +70,12 @@ RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
|
|||||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
|
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
|
||||||
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
||||||
|
|
||||||
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
|
FROM --platform=linux/arm64 rockylinux:8 AS cpu-builder-arm64
|
||||||
ARG CMAKE_VERSION
|
ARG CMAKE_VERSION
|
||||||
ARG GOLANG_VERSION
|
ARG GOLANG_VERSION
|
||||||
COPY ./scripts/rh_linux_deps.sh /
|
COPY ./scripts/rh_linux_deps.sh /
|
||||||
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
|
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
|
||||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||||
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||||
ARG OLLAMA_CUSTOM_CPU_DEFS
|
ARG OLLAMA_CUSTOM_CPU_DEFS
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
|
11
README.md
11
README.md
@ -53,8 +53,8 @@ Here are some example models that can be downloaded:
|
|||||||
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
|
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
|
||||||
| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
|
| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
|
||||||
| Phi 3 Medium | 14B | 7.9GB | `ollama run phi3:medium` |
|
| Phi 3 Medium | 14B | 7.9GB | `ollama run phi3:medium` |
|
||||||
| Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
|
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
|
||||||
| Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
|
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
|
||||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
||||||
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
|
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
|
||||||
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
||||||
@ -182,6 +182,12 @@ $ ollama run llama3 "Summarize this file: $(cat README.md)"
|
|||||||
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Show model information
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama show llama3
|
||||||
|
```
|
||||||
|
|
||||||
### List models on your computer
|
### List models on your computer
|
||||||
|
|
||||||
```
|
```
|
||||||
@ -286,6 +292,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
|
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
|
||||||
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
|
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
|
||||||
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
||||||
|
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
|
||||||
|
|
||||||
### Terminal
|
### Terminal
|
||||||
|
|
||||||
|
23
api/types.go
23
api/types.go
@ -277,6 +277,7 @@ type ShowRequest struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
System string `json:"system"`
|
System string `json:"system"`
|
||||||
Template string `json:"template"`
|
Template string `json:"template"`
|
||||||
|
Verbose bool `json:"verbose"`
|
||||||
|
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
|
|
||||||
@ -293,6 +294,8 @@ type ShowResponse struct {
|
|||||||
System string `json:"system,omitempty"`
|
System string `json:"system,omitempty"`
|
||||||
Details ModelDetails `json:"details,omitempty"`
|
Details ModelDetails `json:"details,omitempty"`
|
||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
|
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||||
|
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -366,6 +369,13 @@ type ProcessModelResponse struct {
|
|||||||
SizeVRAM int64 `json:"size_vram"`
|
SizeVRAM int64 `json:"size_vram"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RetrieveModelResponse struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
}
|
||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
}
|
}
|
||||||
@ -629,6 +639,19 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
|||||||
} else {
|
} else {
|
||||||
field := valueOpts.FieldByName(opt.Name)
|
field := valueOpts.FieldByName(opt.Name)
|
||||||
if field.IsValid() && field.CanSet() {
|
if field.IsValid() && field.CanSet() {
|
||||||
|
if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
|
||||||
|
boolVal, err := strconv.ParseBool(vals[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid bool value %s", vals)
|
||||||
|
}
|
||||||
|
if boolVal {
|
||||||
|
out[key] = TriStateTrue
|
||||||
|
} else {
|
||||||
|
out[key] = TriStateFalse
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
case reflect.Float32:
|
case reflect.Float32:
|
||||||
floatVal, err := strconv.ParseFloat(vals[0], 32)
|
floatVal, err := strconv.ParseFloat(vals[0], 32)
|
||||||
|
@ -2,6 +2,7 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -141,3 +142,65 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUseMmapFormatParams(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req map[string][]string
|
||||||
|
exp TriState
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "True",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"true"},
|
||||||
|
},
|
||||||
|
exp: TriStateTrue,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "False",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"false"},
|
||||||
|
},
|
||||||
|
exp: TriStateFalse,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Numeric True",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"1"},
|
||||||
|
},
|
||||||
|
exp: TriStateTrue,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Numeric False",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"0"},
|
||||||
|
},
|
||||||
|
exp: TriStateFalse,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid string",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"foo"},
|
||||||
|
},
|
||||||
|
exp: TriStateUndefined,
|
||||||
|
err: fmt.Errorf("invalid bool value [foo]"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
resp, err := FormatParams(test.req)
|
||||||
|
require.Equal(t, err, test.err)
|
||||||
|
respVal, ok := resp["use_mmap"]
|
||||||
|
if test.exp != TriStateUndefined {
|
||||||
|
assert.True(t, ok, "resp: %v", resp)
|
||||||
|
assert.Equal(t, test.exp, respVal)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -5,6 +5,8 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
@ -24,6 +26,7 @@ func InitLogging() {
|
|||||||
logFile = os.Stderr
|
logFile = os.Stderr
|
||||||
// TODO - write one-line to the app.log file saying we're running in console mode to help avoid confusion
|
// TODO - write one-line to the app.log file saying we're running in console mode to help avoid confusion
|
||||||
} else {
|
} else {
|
||||||
|
rotateLogs(AppLogFile)
|
||||||
logFile, err = os.OpenFile(AppLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
logFile, err = os.OpenFile(AppLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error(fmt.Sprintf("failed to create server log %v", err))
|
slog.Error(fmt.Sprintf("failed to create server log %v", err))
|
||||||
@ -46,3 +49,32 @@ func InitLogging() {
|
|||||||
|
|
||||||
slog.Info("ollama app started")
|
slog.Info("ollama app started")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rotateLogs(logFile string) {
|
||||||
|
if _, err := os.Stat(logFile); os.IsNotExist(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
index := strings.LastIndex(logFile, ".")
|
||||||
|
pre := logFile[:index]
|
||||||
|
post := "." + logFile[index+1:]
|
||||||
|
for i := LogRotationCount; i > 0; i-- {
|
||||||
|
older := pre + "-" + strconv.Itoa(i) + post
|
||||||
|
newer := pre + "-" + strconv.Itoa(i-1) + post
|
||||||
|
if i == 1 {
|
||||||
|
newer = pre + post
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(newer); err == nil {
|
||||||
|
if _, err := os.Stat(older); err == nil {
|
||||||
|
err := os.Remove(older)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("Failed to remove older log", "older", older, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := os.Rename(newer, older)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("Failed to rotate log", "older", older, "newer", newer, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
44
app/lifecycle/logging_test.go
Normal file
44
app/lifecycle/logging_test.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package lifecycle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRotateLogs(t *testing.T) {
|
||||||
|
logDir := t.TempDir()
|
||||||
|
logFile := filepath.Join(logDir, "testlog.log")
|
||||||
|
|
||||||
|
// No log exists
|
||||||
|
rotateLogs(logFile)
|
||||||
|
|
||||||
|
require.NoError(t, os.WriteFile(logFile, []byte("1"), 0644))
|
||||||
|
assert.FileExists(t, logFile)
|
||||||
|
// First rotation
|
||||||
|
rotateLogs(logFile)
|
||||||
|
assert.FileExists(t, filepath.Join(logDir, "testlog-1.log"))
|
||||||
|
assert.NoFileExists(t, filepath.Join(logDir, "testlog-2.log"))
|
||||||
|
assert.NoFileExists(t, logFile)
|
||||||
|
|
||||||
|
// Should be a no-op without a new log
|
||||||
|
rotateLogs(logFile)
|
||||||
|
assert.FileExists(t, filepath.Join(logDir, "testlog-1.log"))
|
||||||
|
assert.NoFileExists(t, filepath.Join(logDir, "testlog-2.log"))
|
||||||
|
assert.NoFileExists(t, logFile)
|
||||||
|
|
||||||
|
for i := 2; i <= LogRotationCount+1; i++ {
|
||||||
|
require.NoError(t, os.WriteFile(logFile, []byte(strconv.Itoa(i)), 0644))
|
||||||
|
assert.FileExists(t, logFile)
|
||||||
|
rotateLogs(logFile)
|
||||||
|
assert.NoFileExists(t, logFile)
|
||||||
|
for j := 1; j < i; j++ {
|
||||||
|
assert.FileExists(t, filepath.Join(logDir, "testlog-"+strconv.Itoa(j)+".log"))
|
||||||
|
}
|
||||||
|
assert.NoFileExists(t, filepath.Join(logDir, "testlog-"+strconv.Itoa(i+1)+".log"))
|
||||||
|
}
|
||||||
|
}
|
@ -21,6 +21,7 @@ var (
|
|||||||
ServerLogFile = "/tmp/ollama.log"
|
ServerLogFile = "/tmp/ollama.log"
|
||||||
UpgradeLogFile = "/tmp/ollama_update.log"
|
UpgradeLogFile = "/tmp/ollama_update.log"
|
||||||
Installer = "OllamaSetup.exe"
|
Installer = "OllamaSetup.exe"
|
||||||
|
LogRotationCount = 5
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -54,7 +54,7 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
|
|||||||
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
|
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - rotation
|
rotateLogs(ServerLogFile)
|
||||||
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create server log: %w", err)
|
return nil, fmt.Errorf("failed to create server log: %w", err)
|
||||||
|
@ -88,10 +88,15 @@ DialogFontSize=12
|
|||||||
[Files]
|
[Files]
|
||||||
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
|
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
|
||||||
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||||
Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
|
||||||
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
||||||
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
|
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
|
||||||
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
|
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
|
||||||
|
#if DirExists("..\dist\windows-amd64\cuda")
|
||||||
|
Source: "..\dist\windows-amd64\cuda\*"; DestDir: "{app}\cuda\"; Flags: ignoreversion recursesubdirs
|
||||||
|
#endif
|
||||||
|
#if DirExists("..\dist\windows-amd64\oneapi")
|
||||||
|
Source: "..\dist\windows-amd64\oneapi\*"; DestDir: "{app}\oneapi\"; Flags: ignoreversion recursesubdirs
|
||||||
|
#endif
|
||||||
#if DirExists("..\dist\windows-amd64\rocm")
|
#if DirExists("..\dist\windows-amd64\rocm")
|
||||||
Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
|
Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
|
||||||
#endif
|
#endif
|
||||||
|
188
cmd/cmd.go
188
cmd/cmd.go
@ -162,9 +162,6 @@ func tempZipFiles(path string) (string, error) {
|
|||||||
}
|
}
|
||||||
defer tempfile.Close()
|
defer tempfile.Close()
|
||||||
|
|
||||||
zipfile := zip.NewWriter(tempfile)
|
|
||||||
defer zipfile.Close()
|
|
||||||
|
|
||||||
detectContentType := func(path string) (string, error) {
|
detectContentType := func(path string) (string, error) {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -233,6 +230,9 @@ func tempZipFiles(path string) (string, error) {
|
|||||||
files = append(files, tks...)
|
files = append(files, tks...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
zipfile := zip.NewWriter(tempfile)
|
||||||
|
defer zipfile.Close()
|
||||||
|
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
f, err := os.Open(file)
|
f, err := os.Open(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -287,38 +287,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
name := args[0]
|
|
||||||
|
|
||||||
// check if the model exists on the server
|
|
||||||
show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
|
||||||
var statusError api.StatusError
|
|
||||||
switch {
|
|
||||||
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
|
|
||||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case err != nil:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
opts := runOptions{
|
opts := runOptions{
|
||||||
Model: args[0],
|
Model: args[0],
|
||||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||||
Options: map[string]interface{}{},
|
Options: map[string]interface{}{},
|
||||||
MultiModal: slices.Contains(show.Details.Families, "clip"),
|
|
||||||
ParentModel: show.Details.ParentModel,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
format, err := cmd.Flags().GetString("format")
|
format, err := cmd.Flags().GetString("format")
|
||||||
@ -362,11 +336,38 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
opts.WordWrap = !nowrap
|
opts.WordWrap = !nowrap
|
||||||
|
|
||||||
if !interactive {
|
// Fill out the rest of the options based on information about the
|
||||||
return generate(cmd, opts)
|
// model.
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
name := args[0]
|
||||||
|
info, err := func() (*api.ShowResponse, error) {
|
||||||
|
showReq := &api.ShowRequest{Name: name}
|
||||||
|
info, err := client.Show(cmd.Context(), showReq)
|
||||||
|
var se api.StatusError
|
||||||
|
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||||
|
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||||
|
}
|
||||||
|
return info, err
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
opts.MultiModal = slices.Contains(info.Details.Families, "clip")
|
||||||
|
opts.ParentModel = info.Details.ParentModel
|
||||||
|
opts.Messages = append(opts.Messages, info.Messages...)
|
||||||
|
|
||||||
|
if interactive {
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
|
}
|
||||||
|
return generate(cmd, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func errFromUnknownKey(unknownKeyErr error) error {
|
func errFromUnknownKey(unknownKeyErr error) error {
|
||||||
@ -579,10 +580,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(args) != 1 {
|
|
||||||
return errors.New("missing model name")
|
|
||||||
}
|
|
||||||
|
|
||||||
license, errLicense := cmd.Flags().GetBool("license")
|
license, errLicense := cmd.Flags().GetBool("license")
|
||||||
modelfile, errModelfile := cmd.Flags().GetBool("modelfile")
|
modelfile, errModelfile := cmd.Flags().GetBool("modelfile")
|
||||||
parameters, errParams := cmd.Flags().GetBool("parameters")
|
parameters, errParams := cmd.Flags().GetBool("parameters")
|
||||||
@ -625,8 +622,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
if flagsSet > 1 {
|
if flagsSet > 1 {
|
||||||
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
|
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
|
||||||
} else if flagsSet == 0 {
|
|
||||||
return errors.New("one of '--license', '--modelfile', '--parameters', '--system', or '--template' must be specified")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
req := api.ShowRequest{Name: args[0]}
|
req := api.ShowRequest{Name: args[0]}
|
||||||
@ -635,6 +630,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if flagsSet == 1 {
|
||||||
switch showType {
|
switch showType {
|
||||||
case "license":
|
case "license":
|
||||||
fmt.Println(resp.License)
|
fmt.Println(resp.License)
|
||||||
@ -649,6 +645,124 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
showInfo(resp)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func showInfo(resp *api.ShowResponse) {
|
||||||
|
arch := resp.ModelInfo["general.architecture"].(string)
|
||||||
|
|
||||||
|
modelData := [][]string{
|
||||||
|
{"arch", arch},
|
||||||
|
{"parameters", resp.Details.ParameterSize},
|
||||||
|
{"quantization", resp.Details.QuantizationLevel},
|
||||||
|
{"context length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))},
|
||||||
|
{"embedding length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64))},
|
||||||
|
}
|
||||||
|
|
||||||
|
mainTableData := [][]string{
|
||||||
|
{"Model"},
|
||||||
|
{renderSubTable(modelData, false)},
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.ProjectorInfo != nil {
|
||||||
|
projectorData := [][]string{
|
||||||
|
{"arch", "clip"},
|
||||||
|
{"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))},
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectorType, ok := resp.ProjectorInfo["clip.projector_type"]; ok {
|
||||||
|
projectorData = append(projectorData, []string{"projector type", projectorType.(string)})
|
||||||
|
}
|
||||||
|
|
||||||
|
projectorData = append(projectorData,
|
||||||
|
[]string{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))},
|
||||||
|
[]string{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))},
|
||||||
|
)
|
||||||
|
|
||||||
|
mainTableData = append(mainTableData,
|
||||||
|
[]string{"Projector"},
|
||||||
|
[]string{renderSubTable(projectorData, false)},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Parameters != "" {
|
||||||
|
mainTableData = append(mainTableData, []string{"Parameters"}, []string{formatParams(resp.Parameters)})
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.System != "" {
|
||||||
|
mainTableData = append(mainTableData, []string{"System"}, []string{renderSubTable(twoLines(resp.System), true)})
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.License != "" {
|
||||||
|
mainTableData = append(mainTableData, []string{"License"}, []string{renderSubTable(twoLines(resp.License), true)})
|
||||||
|
}
|
||||||
|
|
||||||
|
table := tablewriter.NewWriter(os.Stdout)
|
||||||
|
table.SetAutoWrapText(false)
|
||||||
|
table.SetBorder(false)
|
||||||
|
table.SetAlignment(tablewriter.ALIGN_LEFT)
|
||||||
|
|
||||||
|
for _, v := range mainTableData {
|
||||||
|
table.Append(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
table.Render()
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderSubTable(data [][]string, file bool) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
table := tablewriter.NewWriter(&buf)
|
||||||
|
table.SetAutoWrapText(!file)
|
||||||
|
table.SetBorder(false)
|
||||||
|
table.SetNoWhiteSpace(true)
|
||||||
|
table.SetTablePadding("\t")
|
||||||
|
table.SetAlignment(tablewriter.ALIGN_LEFT)
|
||||||
|
|
||||||
|
for _, v := range data {
|
||||||
|
table.Append(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
table.Render()
|
||||||
|
|
||||||
|
renderedTable := buf.String()
|
||||||
|
lines := strings.Split(renderedTable, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
lines[i] = "\t" + line
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func twoLines(s string) [][]string {
|
||||||
|
lines := strings.Split(s, "\n")
|
||||||
|
res := [][]string{}
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line != "" {
|
||||||
|
count++
|
||||||
|
res = append(res, []string{line})
|
||||||
|
if count == 2 {
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatParams(s string) string {
|
||||||
|
lines := strings.Split(s, "\n")
|
||||||
|
table := [][]string{}
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
table = append(table, strings.Fields(line))
|
||||||
|
}
|
||||||
|
return renderSubTable(table, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyHandler(cmd *cobra.Command, args []string) error {
|
func CopyHandler(cmd *cobra.Command, args []string) error {
|
||||||
|
@ -31,41 +31,24 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func loadModel(cmd *cobra.Command, opts *runOptions) error {
|
func loadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
p := progress.NewProgress(os.Stderr)
|
p := progress.NewProgress(os.Stderr)
|
||||||
defer p.StopAndClear()
|
defer p.StopAndClear()
|
||||||
|
|
||||||
spinner := progress.NewSpinner("")
|
spinner := progress.NewSpinner("")
|
||||||
p.Add("", spinner)
|
p.Add("", spinner)
|
||||||
|
|
||||||
showReq := api.ShowRequest{Name: opts.Model}
|
client, err := api.ClientFromEnvironment()
|
||||||
showResp, err := client.Show(cmd.Context(), &showReq)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
|
|
||||||
opts.ParentModel = showResp.Details.ParentModel
|
|
||||||
|
|
||||||
if len(showResp.Messages) > 0 {
|
|
||||||
opts.Messages = append(opts.Messages, showResp.Messages...)
|
|
||||||
}
|
|
||||||
|
|
||||||
chatReq := &api.ChatRequest{
|
chatReq := &api.ChatRequest{
|
||||||
Model: opts.Model,
|
Model: opts.Model,
|
||||||
Messages: []api.Message{},
|
KeepAlive: opts.KeepAlive,
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.KeepAlive != nil {
|
return client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
|
||||||
chatReq.KeepAlive = opts.KeepAlive
|
|
||||||
}
|
|
||||||
|
|
||||||
err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
|
|
||||||
p.StopAndClear()
|
p.StopAndClear()
|
||||||
if len(opts.Messages) > 0 {
|
|
||||||
for _, msg := range opts.Messages {
|
for _, msg := range opts.Messages {
|
||||||
switch msg.Role {
|
switch msg.Role {
|
||||||
case "user":
|
case "user":
|
||||||
@ -77,19 +60,11 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
|
|||||||
fmt.Println()
|
fmt.Println()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||||
opts.Messages = make([]api.Message, 0)
|
|
||||||
|
|
||||||
err := loadModel(cmd, &opts)
|
err := loadModel(cmd, &opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -429,15 +404,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
|
|
||||||
switch args[1] {
|
switch args[1] {
|
||||||
case "info":
|
case "info":
|
||||||
fmt.Println("Model details:")
|
showInfo(resp)
|
||||||
if len(resp.Details.Families) > 0 {
|
|
||||||
fmt.Printf("Family %s\n", strings.Join(resp.Details.Families, ", "))
|
|
||||||
} else if resp.Details.Family != "" {
|
|
||||||
fmt.Printf("Family %s\n", resp.Details.Family)
|
|
||||||
}
|
|
||||||
fmt.Printf("Parameter Size %s\n", resp.Details.ParameterSize)
|
|
||||||
fmt.Printf("Quantization Level %s\n", resp.Details.QuantizationLevel)
|
|
||||||
fmt.Println("")
|
|
||||||
case "license":
|
case "license":
|
||||||
if resp.License == "" {
|
if resp.License == "" {
|
||||||
fmt.Println("No license was specified for this model.")
|
fmt.Println("No license was specified for this model.")
|
||||||
|
39
docs/api.md
39
docs/api.md
@ -26,7 +26,7 @@ All durations are returned in nanoseconds.
|
|||||||
|
|
||||||
### Streaming responses
|
### Streaming responses
|
||||||
|
|
||||||
Certain endpoints stream responses as JSON objects and can optional return non-streamed responses.
|
Certain endpoints stream responses as JSON objects. Streaming can be disabled by providing `{"stream": false}` for these endpoints.
|
||||||
|
|
||||||
## Generate a completion
|
## Generate a completion
|
||||||
|
|
||||||
@ -777,11 +777,12 @@ A single JSON object will be returned.
|
|||||||
POST /api/show
|
POST /api/show
|
||||||
```
|
```
|
||||||
|
|
||||||
Show information about a model including details, modelfile, template, parameters, license, and system prompt.
|
Show information about a model including details, modelfile, template, parameters, license, system prompt.
|
||||||
|
|
||||||
### Parameters
|
### Parameters
|
||||||
|
|
||||||
- `name`: name of the model to show
|
- `name`: name of the model to show
|
||||||
|
- `verbose`: (optional) if set to `true`, returns full data for verbose response fields
|
||||||
|
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
@ -798,14 +799,40 @@ curl http://localhost:11434/api/show -d '{
|
|||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
|
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
|
||||||
"parameters": "num_ctx 4096\nstop \u003c/s\u003e\nstop USER:\nstop ASSISTANT:",
|
"parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"",
|
||||||
"template": "{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: ",
|
"template": "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>",
|
||||||
"details": {
|
"details": {
|
||||||
|
"parent_model": "",
|
||||||
"format": "gguf",
|
"format": "gguf",
|
||||||
"family": "llama",
|
"family": "llama",
|
||||||
"families": ["llama", "clip"],
|
"families": [
|
||||||
"parameter_size": "7B",
|
"llama"
|
||||||
|
],
|
||||||
|
"parameter_size": "8.0B",
|
||||||
"quantization_level": "Q4_0"
|
"quantization_level": "Q4_0"
|
||||||
|
},
|
||||||
|
"model_info": {
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"general.file_type": 2,
|
||||||
|
"general.parameter_count": 8030261248,
|
||||||
|
"general.quantization_version": 2,
|
||||||
|
"llama.attention.head_count": 32,
|
||||||
|
"llama.attention.head_count_kv": 8,
|
||||||
|
"llama.attention.layer_norm_rms_epsilon": 0.00001,
|
||||||
|
"llama.block_count": 32,
|
||||||
|
"llama.context_length": 8192,
|
||||||
|
"llama.embedding_length": 4096,
|
||||||
|
"llama.feed_forward_length": 14336,
|
||||||
|
"llama.rope.dimension_count": 128,
|
||||||
|
"llama.rope.freq_base": 500000,
|
||||||
|
"llama.vocab_size": 128256,
|
||||||
|
"tokenizer.ggml.bos_token_id": 128000,
|
||||||
|
"tokenizer.ggml.eos_token_id": 128009,
|
||||||
|
"tokenizer.ggml.merges": [], // populates if `verbose=true`
|
||||||
|
"tokenizer.ggml.model": "gpt2",
|
||||||
|
"tokenizer.ggml.pre": "llama-bpe",
|
||||||
|
"tokenizer.ggml.token_type": [], // populates if `verbose=true`
|
||||||
|
"tokenizer.ggml.tokens": [] // populates if `verbose=true`
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
16
docs/faq.md
16
docs/faq.md
@ -257,3 +257,19 @@ If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` AP
|
|||||||
## How do I manage the maximum number of requests the Ollama server can queue?
|
## How do I manage the maximum number of requests the Ollama server can queue?
|
||||||
|
|
||||||
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`.
|
If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`.
|
||||||
|
|
||||||
|
## How does Ollama handle concurrent requests?
|
||||||
|
|
||||||
|
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
|
||||||
|
|
||||||
|
If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads.
|
||||||
|
|
||||||
|
Parallel request processing for a given model results in increasing the context size by the number of parallel requests. For example, a 2K context with 4 parallel requests will result in an 8K context and additional memory allocation.
|
||||||
|
|
||||||
|
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
|
||||||
|
|
||||||
|
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
|
||||||
|
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||||
|
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||||
|
|
||||||
|
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
@ -18,7 +18,7 @@ Check your compute compatibility to see if your card is supported:
|
|||||||
| | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
|
| | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
|
||||||
| 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` |
|
| 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` |
|
||||||
| 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` |
|
| 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` |
|
||||||
| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050` |
|
| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` |
|
||||||
| | Quadro | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` |
|
| | Quadro | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` |
|
||||||
| | Tesla | `P40` `P4` |
|
| | Tesla | `P40` `P4` |
|
||||||
| 6.0 | NVIDIA | `Tesla P100` `Quadro GP100` |
|
| 6.0 | NVIDIA | `Tesla P100` `Quadro GP100` |
|
||||||
|
@ -65,6 +65,7 @@ curl http://localhost:11434/v1/chat/completions \
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}'
|
}'
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Endpoints
|
## Endpoints
|
||||||
@ -104,7 +105,6 @@ curl http://localhost:11434/v1/chat/completions \
|
|||||||
|
|
||||||
#### Notes
|
#### Notes
|
||||||
|
|
||||||
- `finish_reason` will always be `stop`
|
|
||||||
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
|
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
|
||||||
|
|
||||||
## Models
|
## Models
|
||||||
|
@ -22,7 +22,7 @@ docker logs <container-name>
|
|||||||
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
||||||
|
|
||||||
When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in:
|
When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in:
|
||||||
- `explorer %LOCALAPPDATA%\Ollama` to view logs
|
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
|
||||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
|
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
|
||||||
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
|
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
|
||||||
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
|
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
|
||||||
|
@ -39,8 +39,8 @@ server.
|
|||||||
Ollama on Windows stores files in a few different locations. You can view them in
|
Ollama on Windows stores files in a few different locations. You can view them in
|
||||||
the explorer window by hitting `<cmd>+R` and type in:
|
the explorer window by hitting `<cmd>+R` and type in:
|
||||||
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
|
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
|
||||||
- *app.log* contains logs from the GUI application
|
- *app.log* contains most resent logs from the GUI application
|
||||||
- *server.log* contains the server logs
|
- *server.log* contains the most recent server logs
|
||||||
- *upgrade.log* contains log output for upgrades
|
- *upgrade.log* contains log output for upgrades
|
||||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
||||||
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
||||||
|
@ -85,13 +85,13 @@ func AsMap() map[string]EnvVar {
|
|||||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
"OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
|
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
|
||||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
|
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
|
||||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
|
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
|
||||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
|
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
|
||||||
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
|
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
|
||||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
|
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
|
||||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
|
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
|
||||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
|
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
|
||||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
|
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
|
||||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
|
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
|
||||||
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
|
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
|
||||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
|
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
|
||||||
@ -129,8 +129,8 @@ func clean(key string) string {
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// default values
|
// default values
|
||||||
NumParallel = 1
|
NumParallel = 0 // Autoselect
|
||||||
MaxRunners = 1
|
MaxRunners = 0 // Autoselect
|
||||||
MaxQueuedRequests = 512
|
MaxQueuedRequests = 512
|
||||||
|
|
||||||
LoadConfig()
|
LoadConfig()
|
||||||
@ -205,8 +205,8 @@ func LoadConfig() {
|
|||||||
|
|
||||||
if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
|
if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
|
||||||
val, err := strconv.Atoi(onp)
|
val, err := strconv.Atoi(onp)
|
||||||
if err != nil || val <= 0 {
|
if err != nil {
|
||||||
slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
|
slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err)
|
||||||
} else {
|
} else {
|
||||||
NumParallel = val
|
NumParallel = val
|
||||||
}
|
}
|
||||||
@ -251,7 +251,7 @@ func LoadConfig() {
|
|||||||
if maxRunners != "" {
|
if maxRunners != "" {
|
||||||
m, err := strconv.Atoi(maxRunners)
|
m, err := strconv.Atoi(maxRunners)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
|
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
|
||||||
} else {
|
} else {
|
||||||
MaxRunners = m
|
MaxRunners = m
|
||||||
}
|
}
|
||||||
@ -260,7 +260,7 @@ func LoadConfig() {
|
|||||||
if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
|
if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
|
||||||
p, err := strconv.Atoi(onp)
|
p, err := strconv.Atoi(onp)
|
||||||
if err != nil || p <= 0 {
|
if err != nil || p <= 0 {
|
||||||
slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err)
|
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err)
|
||||||
} else {
|
} else {
|
||||||
MaxQueuedRequests = p
|
MaxQueuedRequests = p
|
||||||
}
|
}
|
||||||
|
@ -115,8 +115,6 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO revisit this once ROCm v6 is available on windows.
|
|
||||||
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
|
|
||||||
slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
||||||
slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
|
slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
|
||||||
gpuInfo := RocmGPUInfo{
|
gpuInfo := RocmGPUInfo{
|
||||||
@ -126,6 +124,9 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
|||||||
TotalMemory: totalMemory,
|
TotalMemory: totalMemory,
|
||||||
FreeMemory: freeMemory,
|
FreeMemory: freeMemory,
|
||||||
},
|
},
|
||||||
|
// Free memory reporting on Windows is not reliable until we bump to ROCm v6.2
|
||||||
|
UnreliableFreeMemory: true,
|
||||||
|
|
||||||
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
||||||
DependencyPath: libDir,
|
DependencyPath: libDir,
|
||||||
MinimumMemory: rocmMinimumMemory,
|
MinimumMemory: rocmMinimumMemory,
|
||||||
|
@ -77,20 +77,27 @@ func cleanupTmpDirs() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
raw, err := os.ReadFile(filepath.Join(d, "ollama.pid"))
|
raw, err := os.ReadFile(filepath.Join(d, "ollama.pid"))
|
||||||
if err == nil {
|
if err != nil {
|
||||||
|
slog.Warn("failed to read ollama.pid", "path", d, "error", err)
|
||||||
|
// No pid, ignore this tmpdir
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
pid, err := strconv.Atoi(string(raw))
|
pid, err := strconv.Atoi(string(raw))
|
||||||
if err == nil {
|
if err != nil {
|
||||||
if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
slog.Warn("failed to parse pid", "path", d, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
proc, err := os.FindProcess(pid)
|
||||||
|
if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
||||||
|
slog.Warn("found running ollama", "pid", pid, "path", d)
|
||||||
// Another running ollama, ignore this tmpdir
|
// Another running ollama, ignore this tmpdir
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else {
|
if err := os.Remove(d); err != nil {
|
||||||
slog.Debug("failed to open ollama.pid", "path", d, "error", err)
|
slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err)
|
||||||
}
|
|
||||||
err = os.RemoveAll(d)
|
|
||||||
if err != nil {
|
|
||||||
slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
10
gpu/gpu.go
10
gpu/gpu.go
@ -231,7 +231,7 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
// On windows we bundle the nvidia library one level above the runner dir
|
// On windows we bundle the nvidia library one level above the runner dir
|
||||||
depPath := ""
|
depPath := ""
|
||||||
if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
|
if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
|
||||||
depPath = filepath.Dir(envconfig.RunnersDir)
|
depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "cuda")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load ALL libraries
|
// Load ALL libraries
|
||||||
@ -282,6 +282,12 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
// Intel
|
// Intel
|
||||||
if envconfig.IntelGpu {
|
if envconfig.IntelGpu {
|
||||||
oHandles = initOneAPIHandles()
|
oHandles = initOneAPIHandles()
|
||||||
|
// On windows we bundle the oneapi library one level above the runner dir
|
||||||
|
depPath = ""
|
||||||
|
if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
|
||||||
|
depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "oneapi")
|
||||||
|
}
|
||||||
|
|
||||||
for d := range oHandles.oneapi.num_drivers {
|
for d := range oHandles.oneapi.num_drivers {
|
||||||
if oHandles.oneapi == nil {
|
if oHandles.oneapi == nil {
|
||||||
// shouldn't happen
|
// shouldn't happen
|
||||||
@ -306,7 +312,7 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||||
// TODO dependency path?
|
gpuInfo.DependencyPath = depPath
|
||||||
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
|||||||
|
|
||||||
for (i = 0; l[i].s != NULL; i++) {
|
for (i = 0; l[i].s != NULL; i++) {
|
||||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||||
if (!l[i].p) {
|
if (!*(l[i].p)) {
|
||||||
char *msg = LOAD_ERR();
|
char *msg = LOAD_ERR();
|
||||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||||
UNLOAD_LIBRARY(resp->ch.handle);
|
UNLOAD_LIBRARY(resp->ch.handle);
|
||||||
|
@ -43,7 +43,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
|
|||||||
|
|
||||||
for (i = 0; l[i].s != NULL; i++) {
|
for (i = 0; l[i].s != NULL; i++) {
|
||||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||||
if (!*l[i].p) {
|
if (!*(l[i].p)) {
|
||||||
char *msg = LOAD_ERR();
|
char *msg = LOAD_ERR();
|
||||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||||
UNLOAD_LIBRARY(resp->ch.handle);
|
UNLOAD_LIBRARY(resp->ch.handle);
|
||||||
|
@ -42,7 +42,7 @@ void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
|
|||||||
// LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
// LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
||||||
|
|
||||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||||
if (!l[i].p) {
|
if (!*(l[i].p)) {
|
||||||
resp->ch.handle = NULL;
|
resp->ch.handle = NULL;
|
||||||
char *msg = LOAD_ERR();
|
char *msg = LOAD_ERR();
|
||||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||||
|
@ -50,7 +50,7 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) {
|
|||||||
LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);
|
LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);
|
||||||
|
|
||||||
*l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
|
*l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
|
||||||
if (!l[i].p) {
|
if (!*(l[i].p)) {
|
||||||
resp->oh.handle = NULL;
|
resp->oh.handle = NULL;
|
||||||
char *msg = LOAD_ERR();
|
char *msg = LOAD_ERR();
|
||||||
LOG(resp->oh.verbose, "dlerr: %s\n", msg);
|
LOG(resp->oh.verbose, "dlerr: %s\n", msg);
|
||||||
|
@ -29,6 +29,11 @@ type GpuInfo struct {
|
|||||||
// Extra environment variables specific to the GPU as list of [key,value]
|
// Extra environment variables specific to the GPU as list of [key,value]
|
||||||
EnvWorkarounds [][2]string `json:"envs,omitempty"`
|
EnvWorkarounds [][2]string `json:"envs,omitempty"`
|
||||||
|
|
||||||
|
// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
|
||||||
|
// the FreeMemory is best effort, and may over or under report actual memory usage
|
||||||
|
// False indicates FreeMemory can generally be trusted on this GPU
|
||||||
|
UnreliableFreeMemory bool
|
||||||
|
|
||||||
// GPU information
|
// GPU information
|
||||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||||
Name string `json:"name"` // user friendly name if available
|
Name string `json:"name"` // user friendly name if available
|
||||||
|
57
llm/ext_server/server.cpp
vendored
57
llm/ext_server/server.cpp
vendored
@ -56,7 +56,6 @@ struct server_params {
|
|||||||
std::string hostname = "127.0.0.1";
|
std::string hostname = "127.0.0.1";
|
||||||
std::vector<std::string> api_keys;
|
std::vector<std::string> api_keys;
|
||||||
std::string public_path = "examples/server/public";
|
std::string public_path = "examples/server/public";
|
||||||
std::string chat_template = "";
|
|
||||||
int32_t port = 8080;
|
int32_t port = 8080;
|
||||||
int32_t read_timeout = 600;
|
int32_t read_timeout = 600;
|
||||||
int32_t write_timeout = 600;
|
int32_t write_timeout = 600;
|
||||||
@ -427,16 +426,6 @@ struct llama_server_context
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void validate_model_chat_template(server_params & sparams) {
|
|
||||||
llama_chat_message chat[] = {{"user", "test"}};
|
|
||||||
std::vector<char> buf(1);
|
|
||||||
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
|
|
||||||
if (res < 0) {
|
|
||||||
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
|
||||||
sparams.chat_template = "chatml";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void initialize() {
|
void initialize() {
|
||||||
// create slots
|
// create slots
|
||||||
all_slots_are_idle = true;
|
all_slots_are_idle = true;
|
||||||
@ -1661,26 +1650,41 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||||
|
|
||||||
|
char buf[256];
|
||||||
|
llama_model_meta_val_str(model, "general.architecture", buf, 256);
|
||||||
|
bool gemma2 = strcmp(buf, "gemma2") == 0;
|
||||||
|
|
||||||
|
int32_t truncate_at = slot.n_ctx;
|
||||||
|
|
||||||
|
// truncate at 2/3 of the context length for gemma2 models
|
||||||
|
// as they do not support context shifts (from the sliding window implementation).
|
||||||
|
// this way, prompts that almost fit the context length can still generate a full
|
||||||
|
// response without a sudden stop from hitting the context limit
|
||||||
|
if (gemma2) {
|
||||||
|
truncate_at = 2 * slot.n_ctx / 3;
|
||||||
|
}
|
||||||
|
|
||||||
// if input prompt is too big, truncate it, if group attention self-extend is disabled
|
// if input prompt is too big, truncate it, if group attention self-extend is disabled
|
||||||
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
|
if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
|
||||||
{
|
{
|
||||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||||
const int n_block_size = n_left / 2;
|
const int n_shift = n_left / 2;
|
||||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
const int n_erase = slot.n_prompt_tokens - slot.params.n_keep - n_shift;
|
||||||
|
|
||||||
std::vector<llama_token> new_tokens(
|
std::vector<llama_token> new_tokens(
|
||||||
prompt_tokens.begin(),
|
prompt_tokens.begin(),
|
||||||
prompt_tokens.begin() + slot.params.n_keep);
|
prompt_tokens.begin() + slot.params.n_keep);
|
||||||
new_tokens.insert(
|
new_tokens.insert(
|
||||||
new_tokens.end(),
|
new_tokens.end(),
|
||||||
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
|
prompt_tokens.begin() + slot.params.n_keep + n_erase,
|
||||||
prompt_tokens.end());
|
prompt_tokens.end());
|
||||||
|
|
||||||
LOG_VERBOSE("input truncated", {
|
LOG_INFO("input truncated", {
|
||||||
{"n_ctx", slot.n_ctx},
|
{"n_ctx", slot.n_ctx},
|
||||||
{"n_keep", slot.params.n_keep},
|
{"n_keep", slot.params.n_keep},
|
||||||
{"n_left", n_left},
|
{"n_left", n_left},
|
||||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
{"n_shift", n_shift},
|
||||||
|
{"n_erase", n_erase},
|
||||||
});
|
});
|
||||||
slot.truncated = true;
|
slot.truncated = true;
|
||||||
prompt_tokens = new_tokens;
|
prompt_tokens = new_tokens;
|
||||||
@ -1689,6 +1693,19 @@ struct llama_server_context
|
|||||||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Models with sliding window attention do not work with context shifts, so
|
||||||
|
// limit their prediction to the context length
|
||||||
|
if (gemma2) {
|
||||||
|
int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
|
||||||
|
slot.n_predict = limit;
|
||||||
|
slot.params.n_predict = limit;
|
||||||
|
LOG_INFO("model does not support sliding window, limiting generation", {
|
||||||
|
{"n_ctx", slot.n_ctx},
|
||||||
|
{"n_prompt_tokens", slot.n_prompt_tokens},
|
||||||
|
{"n_predict", slot.n_predict}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
if (!slot.params.cache_prompt)
|
if (!slot.params.cache_prompt)
|
||||||
{
|
{
|
||||||
llama_sampling_reset(slot.ctx_sampling);
|
llama_sampling_reset(slot.ctx_sampling);
|
||||||
@ -2535,7 +2552,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
|
|||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
sparams.chat_template = argv[i];
|
|
||||||
}
|
}
|
||||||
else if (arg == "--override-kv")
|
else if (arg == "--override-kv")
|
||||||
{
|
{
|
||||||
@ -3008,11 +3024,6 @@ int main(int argc, char **argv) {
|
|||||||
}
|
}
|
||||||
const auto model_meta = llama.model_meta();
|
const auto model_meta = llama.model_meta();
|
||||||
|
|
||||||
if (sparams.chat_template.empty()) { // custom chat template is not supplied
|
|
||||||
// check if the template comes with the model is supported by us
|
|
||||||
llama.validate_model_chat_template(sparams);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Middleware for API key validation
|
// Middleware for API key validation
|
||||||
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
|
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
|
||||||
// If API key is not set, skip validation
|
// If API key is not set, skip validation
|
||||||
|
@ -295,10 +295,12 @@ function build_cuda() {
|
|||||||
sign
|
sign
|
||||||
install
|
install
|
||||||
|
|
||||||
write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
|
||||||
cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\" -ea 0 > $null
|
||||||
cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
|
||||||
cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
|
||||||
|
cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
|
||||||
|
cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\"
|
||||||
} else {
|
} else {
|
||||||
write-host "Skipping CUDA generation step"
|
write-host "Skipping CUDA generation step"
|
||||||
}
|
}
|
||||||
@ -332,16 +334,18 @@ function build_oneapi() {
|
|||||||
sign
|
sign
|
||||||
install
|
install
|
||||||
|
|
||||||
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libirngmd.dll" "${script:distDir}"
|
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libmmd.dll" "${script:distDir}"
|
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" -ea 0 > $null
|
||||||
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_level_zero.dll" "${script:distDir}"
|
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libirngmd.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_unified_runtime.dll" "${script:distDir}"
|
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libmmd.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_win_proxy_loader.dll" "${script:distDir}"
|
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_level_zero.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\svml_dispmd.dll" "${script:distDir}"
|
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_unified_runtime.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\sycl7.dll" "${script:distDir}"
|
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_win_proxy_loader.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_core.2.dll" "${script:distDir}"
|
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\svml_dispmd.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_sycl_blas.4.dll" "${script:distDir}"
|
cp "${env:ONEAPI_ROOT}\compiler\latest\bin\sycl7.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_tbb_thread.2.dll" "${script:distDir}"
|
cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_core.2.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
|
cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_sycl_blas.4.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
|
cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_tbb_thread.2.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\"
|
||||||
} else {
|
} else {
|
||||||
Write-Host "Skipping oneAPI generation step"
|
Write-Host "Skipping oneAPI generation step"
|
||||||
}
|
}
|
||||||
|
13
llm/ggla.go
13
llm/ggla.go
@ -53,7 +53,7 @@ func (llm *ggla) Tensors() Tensors {
|
|||||||
return llm.tensors
|
return llm.tensors
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *ggla) decode(rs io.ReadSeeker) error {
|
func (llm *ggla) decode(rs io.ReadSeeker) (retErr error) {
|
||||||
var r uint32
|
var r uint32
|
||||||
if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
|
if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -69,9 +69,18 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
|
|||||||
for {
|
for {
|
||||||
var dims uint32
|
var dims uint32
|
||||||
if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {
|
if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if errors.Is(retErr, io.EOF) {
|
||||||
|
retErr = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
var namesize uint32
|
var namesize uint32
|
||||||
if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil {
|
if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -108,7 +117,7 @@ func (llm *ggla) decode(rs io.ReadSeeker) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := rs.Seek((offset+31)&-32, io.SeekStart); err != nil {
|
if _, err := rs.Seek((offset+31)&-32-offset, io.SeekCurrent); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
85
llm/ggml.go
85
llm/ggml.go
@ -6,6 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/util/bufioutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GGML struct {
|
type GGML struct {
|
||||||
@ -69,6 +71,30 @@ func (kv KV) HeadCountKV() uint64 {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (kv KV) EmbeddingHeadCount() uint64 {
|
||||||
|
if heads := kv.HeadCount(); heads > 0 {
|
||||||
|
return kv.EmbeddingLength() / kv.HeadCount()
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) EmbeddingHeadCountK() uint64 {
|
||||||
|
if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv.EmbeddingHeadCount()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) EmbeddingHeadCountV() uint64 {
|
||||||
|
if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv.EmbeddingHeadCount()
|
||||||
|
}
|
||||||
|
|
||||||
func (kv KV) GQA() uint64 {
|
func (kv KV) GQA() uint64 {
|
||||||
return kv.HeadCount() / kv.HeadCountKV()
|
return kv.HeadCount() / kv.HeadCountKV()
|
||||||
}
|
}
|
||||||
@ -254,7 +280,18 @@ func DetectGGMLType(b []byte) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
|
// DecodeGGML decodes a GGML model from the given reader.
|
||||||
|
//
|
||||||
|
// It collects array values for arrays with a size less than or equal to
|
||||||
|
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
|
||||||
|
// the maxArraySize is negative, all arrays are collected.
|
||||||
|
func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||||
|
if maxArraySize == 0 {
|
||||||
|
maxArraySize = 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
|
||||||
|
|
||||||
var magic uint32
|
var magic uint32
|
||||||
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
|
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@ -267,17 +304,15 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
|
|||||||
case FILE_MAGIC_GGLA:
|
case FILE_MAGIC_GGLA:
|
||||||
c = &containerGGLA{}
|
c = &containerGGLA{}
|
||||||
case FILE_MAGIC_GGUF_LE:
|
case FILE_MAGIC_GGUF_LE:
|
||||||
c = &containerGGUF{ByteOrder: binary.LittleEndian}
|
c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
|
||||||
case FILE_MAGIC_GGUF_BE:
|
case FILE_MAGIC_GGUF_BE:
|
||||||
c = &containerGGUF{ByteOrder: binary.BigEndian}
|
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
|
||||||
default:
|
default:
|
||||||
return nil, 0, errors.New("invalid file magic")
|
return nil, 0, errors.New("invalid file magic")
|
||||||
}
|
}
|
||||||
|
|
||||||
model, err := c.Decode(rs)
|
model, err := c.Decode(rs)
|
||||||
if errors.Is(err, io.EOF) {
|
if err != nil {
|
||||||
// noop
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -297,7 +332,10 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|||||||
embedding := llm.KV().EmbeddingLength()
|
embedding := llm.KV().EmbeddingLength()
|
||||||
heads := llm.KV().HeadCount()
|
heads := llm.KV().HeadCount()
|
||||||
headsKV := llm.KV().HeadCountKV()
|
headsKV := llm.KV().HeadCountKV()
|
||||||
vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
|
vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
|
||||||
|
|
||||||
|
embeddingHeads := llm.KV().EmbeddingHeadCount()
|
||||||
|
embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
|
||||||
|
|
||||||
layers := llm.Tensors().Layers()
|
layers := llm.Tensors().Layers()
|
||||||
|
|
||||||
@ -308,7 +346,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|||||||
partialOffload = 4 * batch * embedding
|
partialOffload = 4 * batch * embedding
|
||||||
partialOffload += max(
|
partialOffload += max(
|
||||||
// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
|
// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
|
||||||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
|
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
|
||||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -316,21 +354,30 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|||||||
// mixtral 8x22b
|
// mixtral 8x22b
|
||||||
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
|
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
|
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
|
||||||
4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
|
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
|
||||||
)
|
)
|
||||||
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
|
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
|
||||||
// mixtral 8x7b
|
// mixtral 8x7b
|
||||||
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
||||||
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
|
4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
|
||||||
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
case "gemma":
|
case "gemma", "gemma2":
|
||||||
fullOffload = 4 * batch * (embedding + vocab)
|
fullOffload = max(
|
||||||
partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
|
4*batch*(embedding+vocab),
|
||||||
|
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
||||||
|
)
|
||||||
|
|
||||||
|
partialOffload = max(
|
||||||
|
4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
|
||||||
|
4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
|
||||||
|
4*embeddingHeadsK*context*8+
|
||||||
|
embedding*embeddingHeadsK*heads*9/16,
|
||||||
|
)
|
||||||
case "command-r":
|
case "command-r":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
@ -367,6 +414,16 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|||||||
4*batch*(vocab+2*embedding),
|
4*batch*(vocab+2*embedding),
|
||||||
fullOffload,
|
fullOffload,
|
||||||
)
|
)
|
||||||
|
case "deepseek2":
|
||||||
|
fullOffload = max(
|
||||||
|
4*batch*(3*embedding+vocab),
|
||||||
|
4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
|
||||||
|
)
|
||||||
|
|
||||||
|
partialOffload = max(
|
||||||
|
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
|
||||||
|
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
1
llm/ggml_test.go
Normal file
1
llm/ggml_test.go
Normal file
@ -0,0 +1 @@
|
|||||||
|
package llm
|
118
llm/gguf.go
118
llm/gguf.go
@ -3,11 +3,10 @@ package llm
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"log/slog"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type containerGGUF struct {
|
type containerGGUF struct {
|
||||||
@ -29,6 +28,12 @@ type containerGGUF struct {
|
|||||||
NumTensor uint64
|
NumTensor uint64
|
||||||
NumKV uint64
|
NumKV uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
maxArraySize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *containerGGUF) canCollectArray(size int) bool {
|
||||||
|
return c.maxArraySize < 0 || size <= c.maxArraySize
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *containerGGUF) Name() string {
|
func (c *containerGGUF) Name() string {
|
||||||
@ -54,7 +59,6 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
model := newGGUF(c)
|
model := newGGUF(c)
|
||||||
slog.Debug(fmt.Sprintf("model = %#v", model))
|
|
||||||
if err := model.Decode(rs); err != nil {
|
if err := model.Decode(rs); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -85,6 +89,8 @@ type gguf struct {
|
|||||||
tensors []*Tensor
|
tensors []*Tensor
|
||||||
|
|
||||||
parameters uint64
|
parameters uint64
|
||||||
|
|
||||||
|
scratch [16 << 10]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGGUF(container *containerGGUF) *gguf {
|
func newGGUF(container *containerGGUF) *gguf {
|
||||||
@ -181,34 +187,34 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// decode tensors
|
// decode tensors
|
||||||
for i := 0; uint64(i) < llm.numTensor(); i++ {
|
for range llm.numTensor() {
|
||||||
name, err := readGGUFString(llm, rs)
|
name, err := readGGUFString(llm, rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to read tensor name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// dims is the number of dimensions in the tensor
|
// dims is the number of dimensions in the tensor
|
||||||
dims, err := readGGUF[uint32](llm, rs)
|
dims, err := readGGUF[uint32](llm, rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to read tensor dimensions: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
shape := [4]uint64{1, 1, 1, 1}
|
shape := [4]uint64{1, 1, 1, 1}
|
||||||
for i := 0; uint32(i) < dims; i++ {
|
for i := 0; uint32(i) < dims; i++ {
|
||||||
shape[i], err = readGGUF[uint64](llm, rs)
|
shape[i], err = readGGUF[uint64](llm, rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to read tensor shape: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kind, err := readGGUF[uint32](llm, rs)
|
kind, err := readGGUF[uint32](llm, rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to read tensor kind: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
offset, err := readGGUF[uint64](llm, rs)
|
offset, err := readGGUF[uint64](llm, rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to read tensor offset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tensor := Tensor{
|
tensor := Tensor{
|
||||||
@ -230,24 +236,19 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
|||||||
alignment = 32
|
alignment = 32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, tensor := range llm.tensors {
|
||||||
offset, err := rs.Seek(0, io.SeekCurrent)
|
offset, err := rs.Seek(0, io.SeekCurrent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to get current offset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
padding := llm.padding(offset, int64(alignment))
|
padding := llm.padding(offset, int64(alignment))
|
||||||
if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
|
if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to seek to init padding: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tensor := range llm.tensors {
|
|
||||||
if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
|
if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to seek to tensor: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
padding := llm.padding(int64(tensor.Size()), int64(alignment))
|
|
||||||
if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -285,22 +286,48 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
|
|||||||
return b.String(), nil
|
return b.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func discardGGUFString(llm *gguf, r io.Reader) error {
|
||||||
|
buf := llm.scratch[:8]
|
||||||
|
_, err := io.ReadFull(r, buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
size := int(llm.ByteOrder.Uint64(buf))
|
||||||
|
for size > 0 {
|
||||||
|
n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
size -= n
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func readGGUFString(llm *gguf, r io.Reader) (string, error) {
|
func readGGUFString(llm *gguf, r io.Reader) (string, error) {
|
||||||
if llm.Version == 1 {
|
if llm.Version == 1 {
|
||||||
return readGGUFV1String(llm, r)
|
return readGGUFV1String(llm, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
var length uint64
|
buf := llm.scratch[:8]
|
||||||
if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
|
_, err := io.ReadFull(r, buf)
|
||||||
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
length := int(llm.ByteOrder.Uint64(buf))
|
||||||
if _, err := io.CopyN(&b, r, int64(length)); err != nil {
|
if length > len(llm.scratch) {
|
||||||
|
buf = make([]byte, length)
|
||||||
|
} else {
|
||||||
|
buf = llm.scratch[:length]
|
||||||
|
}
|
||||||
|
clear(buf)
|
||||||
|
|
||||||
|
_, err = io.ReadFull(r, buf)
|
||||||
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
return string(buf), nil
|
||||||
return b.String(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeGGUFString(llm *gguf, w io.Writer, s string) error {
|
func writeGGUFString(llm *gguf, w io.Writer, s string) error {
|
||||||
@ -316,7 +343,16 @@ func writeGGUFString(llm *gguf, w io.Writer, s string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
|
type array struct {
|
||||||
|
size int
|
||||||
|
values []any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *array) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(a.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
|
||||||
t, err := readGGUF[uint32](llm, r)
|
t, err := readGGUF[uint32](llm, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -327,7 +363,12 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; uint32(i) < n; i++ {
|
a := &array{size: int(n)}
|
||||||
|
if llm.canCollectArray(int(n)) {
|
||||||
|
a.values = make([]any, 0, int(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range n {
|
||||||
var e any
|
var e any
|
||||||
switch t {
|
switch t {
|
||||||
case ggufTypeUint8:
|
case ggufTypeUint8:
|
||||||
@ -361,13 +402,15 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
a = append(a, e)
|
if a.values != nil {
|
||||||
|
a.values[i] = e
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
|
func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
|
||||||
if llm.Version == 1 {
|
if llm.Version == 1 {
|
||||||
return readGGUFV1Array(llm, r)
|
return readGGUFV1Array(llm, r)
|
||||||
}
|
}
|
||||||
@ -382,7 +425,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; uint64(i) < n; i++ {
|
a := &array{size: int(n)}
|
||||||
|
if llm.canCollectArray(int(n)) {
|
||||||
|
a.values = make([]any, int(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range n {
|
||||||
var e any
|
var e any
|
||||||
switch t {
|
switch t {
|
||||||
case ggufTypeUint8:
|
case ggufTypeUint8:
|
||||||
@ -408,7 +456,11 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
|
|||||||
case ggufTypeBool:
|
case ggufTypeBool:
|
||||||
e, err = readGGUF[bool](llm, r)
|
e, err = readGGUF[bool](llm, r)
|
||||||
case ggufTypeString:
|
case ggufTypeString:
|
||||||
|
if a.values != nil {
|
||||||
e, err = readGGUFString(llm, r)
|
e, err = readGGUFString(llm, r)
|
||||||
|
} else {
|
||||||
|
err = discardGGUFString(llm, r)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid array type: %d", t)
|
return nil, fmt.Errorf("invalid array type: %d", t)
|
||||||
}
|
}
|
||||||
@ -416,10 +468,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
a = append(a, e)
|
if a.values != nil {
|
||||||
|
a.values[i] = e
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error {
|
func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error {
|
||||||
|
@ -115,8 +115,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
|
|||||||
slog.Warn("model missing blk.0 layer size")
|
slog.Warn("model missing blk.0 layer size")
|
||||||
}
|
}
|
||||||
|
|
||||||
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
|
// fp16 k,v = sizeof(float16) * n_ctx * n_layer * (n_embd_head_k + n_embd_head_v) * n_head_kv
|
||||||
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
|
var kv uint64 = 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * (ggml.KV().EmbeddingHeadCountK() + ggml.KV().EmbeddingHeadCountV()) * ggml.KV().HeadCountKV()
|
||||||
|
|
||||||
// KV is proportional to the number of layers
|
// KV is proportional to the number of layers
|
||||||
layerSize += kv / ggml.KV().BlockCount()
|
layerSize += kv / ggml.KV().BlockCount()
|
||||||
|
@ -22,13 +22,14 @@ func TestEstimateGPULayers(t *testing.T) {
|
|||||||
defer f.Close()
|
defer f.Close()
|
||||||
gguf := NewGGUFV3(binary.LittleEndian)
|
gguf := NewGGUFV3(binary.LittleEndian)
|
||||||
inputLayerCount := 5
|
inputLayerCount := 5
|
||||||
|
|
||||||
tensors := []Tensor{
|
tensors := []Tensor{
|
||||||
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||||
{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||||
{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||||
{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||||
{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||||
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||||
}
|
}
|
||||||
assert.Len(t, tensors, inputLayerCount+1)
|
assert.Len(t, tensors, inputLayerCount+1)
|
||||||
err = gguf.Encode(f, KV{
|
err = gguf.Encode(f, KV{
|
||||||
@ -45,8 +46,10 @@ func TestEstimateGPULayers(t *testing.T) {
|
|||||||
}, tensors)
|
}, tensors)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ggml, err := LoadModel(f.Name())
|
ggml, err := LoadModel(f.Name(), 0)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
// Simple CPU scenario
|
// Simple CPU scenario
|
||||||
gpus := []gpu.GpuInfo{
|
gpus := []gpu.GpuInfo{
|
||||||
|
305
llm/patches/07-gemma.diff
Normal file
305
llm/patches/07-gemma.diff
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
From 5cadb45f39d001ffbad95b690d6cf0abcb4a6d96 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Ollama maintainers <hello@ollama.com>
|
||||||
|
Date: Wed, 26 Jun 2024 16:18:09 -0700
|
||||||
|
Subject: [PATCH] Architecture support
|
||||||
|
|
||||||
|
---
|
||||||
|
llama.cpp | 194 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
|
||||||
|
1 file changed, 193 insertions(+), 1 deletion(-)
|
||||||
|
|
||||||
|
diff --git a/llama.cpp b/llama.cpp
|
||||||
|
index 61948751..3b4196f5 100644
|
||||||
|
--- a/llama.cpp
|
||||||
|
+++ b/llama.cpp
|
||||||
|
@@ -217,6 +217,7 @@ enum llm_arch {
|
||||||
|
LLM_ARCH_INTERNLM2,
|
||||||
|
LLM_ARCH_MINICPM,
|
||||||
|
LLM_ARCH_GEMMA,
|
||||||
|
+ LLM_ARCH_GEMMA2,
|
||||||
|
LLM_ARCH_STARCODER2,
|
||||||
|
LLM_ARCH_MAMBA,
|
||||||
|
LLM_ARCH_XVERSE,
|
||||||
|
@@ -255,6 +256,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
|
{ LLM_ARCH_INTERNLM2, "internlm2" },
|
||||||
|
{ LLM_ARCH_MINICPM, "minicpm" },
|
||||||
|
{ LLM_ARCH_GEMMA, "gemma" },
|
||||||
|
+ { LLM_ARCH_GEMMA2, "gemma2" },
|
||||||
|
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||||
|
{ LLM_ARCH_MAMBA, "mamba" },
|
||||||
|
{ LLM_ARCH_XVERSE, "xverse" },
|
||||||
|
@@ -464,10 +466,12 @@ enum llm_tensor {
|
||||||
|
LLM_TENSOR_ATTN_NORM,
|
||||||
|
LLM_TENSOR_ATTN_NORM_2,
|
||||||
|
LLM_TENSOR_ATTN_OUT_NORM,
|
||||||
|
+ LLM_TENSOR_ATTN_POST_NORM,
|
||||||
|
LLM_TENSOR_ATTN_ROT_EMBD,
|
||||||
|
LLM_TENSOR_FFN_GATE_INP,
|
||||||
|
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||||
|
LLM_TENSOR_FFN_NORM,
|
||||||
|
+ LLM_TENSOR_FFN_POST_NORM,
|
||||||
|
LLM_TENSOR_FFN_GATE,
|
||||||
|
LLM_TENSOR_FFN_DOWN,
|
||||||
|
LLM_TENSOR_FFN_UP,
|
||||||
|
@@ -960,6 +964,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
+ {
|
||||||
|
+ LLM_ARCH_GEMMA2,
|
||||||
|
+ {
|
||||||
|
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||||
|
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||||
|
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
|
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||||
|
+ },
|
||||||
|
+ },
|
||||||
|
{
|
||||||
|
LLM_ARCH_STARCODER2,
|
||||||
|
{
|
||||||
|
@@ -1941,6 +1963,8 @@ enum e_model {
|
||||||
|
MODEL_8x22B,
|
||||||
|
MODEL_16x12B,
|
||||||
|
MODEL_10B_128x3_66B,
|
||||||
|
+ MODEL_9B,
|
||||||
|
+ MODEL_27B,
|
||||||
|
};
|
||||||
|
|
||||||
|
static const size_t kiB = 1024;
|
||||||
|
@@ -2114,6 +2138,7 @@ struct llama_layer {
|
||||||
|
struct ggml_tensor * attn_out_norm_b;
|
||||||
|
struct ggml_tensor * attn_q_a_norm;
|
||||||
|
struct ggml_tensor * attn_kv_a_norm;
|
||||||
|
+ struct ggml_tensor * attn_post_norm;
|
||||||
|
|
||||||
|
// attention
|
||||||
|
struct ggml_tensor * wq;
|
||||||
|
@@ -2136,6 +2161,7 @@ struct llama_layer {
|
||||||
|
// normalization
|
||||||
|
struct ggml_tensor * ffn_norm;
|
||||||
|
struct ggml_tensor * ffn_norm_b;
|
||||||
|
+ struct ggml_tensor * ffn_post_norm;
|
||||||
|
struct ggml_tensor * layer_out_norm;
|
||||||
|
struct ggml_tensor * layer_out_norm_b;
|
||||||
|
struct ggml_tensor * ffn_norm_exps;
|
||||||
|
@@ -4529,6 +4555,16 @@ static void llm_load_hparams(
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case LLM_ARCH_GEMMA:
|
||||||
|
+ {
|
||||||
|
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
+
|
||||||
|
+ switch (hparams.n_layer) {
|
||||||
|
+ case 18: model.type = e_model::MODEL_9B; break;
|
||||||
|
+ case 28: model.type = e_model::MODEL_27B; break;
|
||||||
|
+ default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
+ }
|
||||||
|
+ } break;
|
||||||
|
+ case LLM_ARCH_GEMMA2:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
|
||||||
|
@@ -6305,6 +6341,40 @@ static bool llm_load_tensors(
|
||||||
|
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
+ case LLM_ARCH_GEMMA2:
|
||||||
|
+ {
|
||||||
|
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
|
+
|
||||||
|
+ // output
|
||||||
|
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||||
|
+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
|
||||||
|
+
|
||||||
|
+ const int64_t n_ff = hparams.n_ff;
|
||||||
|
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||||
|
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||||
|
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
+
|
||||||
|
+ for (uint32_t i = 0; i < n_layer; ++i) {
|
||||||
|
+ ggml_context * ctx_layer = ctx_for_layer(i);
|
||||||
|
+ ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||||
|
+
|
||||||
|
+ auto & layer = model.layers[i];
|
||||||
|
+
|
||||||
|
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||||
|
+
|
||||||
|
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
|
||||||
|
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
|
||||||
|
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
|
||||||
|
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
|
||||||
|
+ layer.attn_post_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
|
||||||
|
+
|
||||||
|
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||||
|
+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||||
|
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||||
|
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||||
|
+ layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
|
||||||
|
+ }
|
||||||
|
+ } break;
|
||||||
|
case LLM_ARCH_STARCODER2:
|
||||||
|
{
|
||||||
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
|
@@ -10614,6 +10684,123 @@ struct llm_build_context {
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
+ struct ggml_cgraph * build_gemma2() {
|
||||||
|
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
+
|
||||||
|
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||||
|
+
|
||||||
|
+ struct ggml_tensor * cur;
|
||||||
|
+ struct ggml_tensor * inpL;
|
||||||
|
+
|
||||||
|
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||||
|
+
|
||||||
|
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
|
||||||
|
+ cb(inpL, "inp_scaled", -1);
|
||||||
|
+
|
||||||
|
+ // inp_pos - contains the positions
|
||||||
|
+ struct ggml_tensor * inp_pos = build_inp_pos();
|
||||||
|
+
|
||||||
|
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
+
|
||||||
|
+ for (int il = 0; il < n_layer; ++il) {
|
||||||
|
+ // norm
|
||||||
|
+ cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
+ model.layers[il].attn_norm, NULL,
|
||||||
|
+ LLM_NORM_RMS, cb, il);
|
||||||
|
+ cb(cur, "attn_norm", il);
|
||||||
|
+
|
||||||
|
+ // self-attention
|
||||||
|
+ {
|
||||||
|
+ // compute Q and K and RoPE them
|
||||||
|
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||||
|
+ cb(Qcur, "Qcur", il);
|
||||||
|
+
|
||||||
|
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||||
|
+ cb(Kcur, "Kcur", il);
|
||||||
|
+
|
||||||
|
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||||
|
+ cb(Vcur, "Vcur", il);
|
||||||
|
+
|
||||||
|
+ Qcur = ggml_rope_ext(
|
||||||
|
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
||||||
|
+ n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
+ ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
|
+ cb(Qcur, "Qcur", il);
|
||||||
|
+
|
||||||
|
+ Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
|
||||||
|
+ cb(Qcur, "Qcur_scaled", il);
|
||||||
|
+
|
||||||
|
+ Kcur = ggml_rope_ext(
|
||||||
|
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
|
+ n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
+ ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
|
+ cb(Kcur, "Kcur", il);
|
||||||
|
+
|
||||||
|
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
|
+ model.layers[il].wo, NULL,
|
||||||
|
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ if (il == n_layer - 1) {
|
||||||
|
+ // skip computing output for unused tokens
|
||||||
|
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||||
|
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
+ model.layers[il].attn_post_norm, NULL,
|
||||||
|
+ LLM_NORM_RMS, cb, il);
|
||||||
|
+ cb(cur, "attn_post_norm", il);
|
||||||
|
+
|
||||||
|
+ struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
|
||||||
|
+ cb(sa_out, "sa_out", il);
|
||||||
|
+
|
||||||
|
+ cur = llm_build_norm(ctx0, sa_out, hparams,
|
||||||
|
+ model.layers[il].ffn_norm, NULL,
|
||||||
|
+ LLM_NORM_RMS, cb, il);
|
||||||
|
+ cb(cur, "ffn_norm", il);
|
||||||
|
+
|
||||||
|
+ // feed-forward network
|
||||||
|
+ {
|
||||||
|
+ cur = llm_build_ffn(ctx0, cur,
|
||||||
|
+ model.layers[il].ffn_up, NULL,
|
||||||
|
+ model.layers[il].ffn_gate, NULL,
|
||||||
|
+ model.layers[il].ffn_down, NULL,
|
||||||
|
+ NULL,
|
||||||
|
+ LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
|
||||||
|
+ cb(cur, "ffn_out", il);
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
+ model.layers[il].ffn_post_norm, NULL,
|
||||||
|
+ LLM_NORM_RMS, cb, -1);
|
||||||
|
+ cb(cur, "ffn_post_norm", -1);
|
||||||
|
+
|
||||||
|
+ cur = ggml_add(ctx0, cur, sa_out);
|
||||||
|
+ cb(cur, "l_out", il);
|
||||||
|
+
|
||||||
|
+ // input for next layer
|
||||||
|
+ inpL = cur;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ cur = inpL;
|
||||||
|
+
|
||||||
|
+ cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
+ model.output_norm, NULL,
|
||||||
|
+ LLM_NORM_RMS, cb, -1);
|
||||||
|
+ cb(cur, "result_norm", -1);
|
||||||
|
+
|
||||||
|
+ // lm_head
|
||||||
|
+ cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||||
|
+ cb(cur, "result_output", -1);
|
||||||
|
+
|
||||||
|
+ ggml_build_forward_expand(gf, cur);
|
||||||
|
+
|
||||||
|
+ return gf;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
struct ggml_cgraph * build_starcoder2() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
|
||||||
|
@@ -11847,6 +12034,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
|
{
|
||||||
|
result = llm.build_gemma();
|
||||||
|
} break;
|
||||||
|
+ case LLM_ARCH_GEMMA2:
|
||||||
|
+ {
|
||||||
|
+ result = llm.build_gemma2();
|
||||||
|
+ } break;
|
||||||
|
case LLM_ARCH_STARCODER2:
|
||||||
|
{
|
||||||
|
result = llm.build_starcoder2();
|
||||||
|
@@ -16671,6 +16862,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||||
|
case LLM_ARCH_PHI2:
|
||||||
|
case LLM_ARCH_PHI3:
|
||||||
|
case LLM_ARCH_GEMMA:
|
||||||
|
+ case LLM_ARCH_GEMMA2:
|
||||||
|
case LLM_ARCH_STARCODER2:
|
||||||
|
case LLM_ARCH_GPTNEOX:
|
||||||
|
return LLAMA_ROPE_TYPE_NEOX;
|
||||||
|
@@ -18551,7 +18743,7 @@ static int32_t llama_chat_apply_template_internal(
|
||||||
|
if (add_ass) {
|
||||||
|
ss << "<s>assistant\n";
|
||||||
|
}
|
||||||
|
- } else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
|
||||||
|
+ } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl.find("<start_of_turn>") != std::string::npos) {
|
||||||
|
// google/gemma-7b-it
|
||||||
|
std::string system_prompt = "";
|
||||||
|
for (auto message : chat) {
|
||||||
|
--
|
||||||
|
2.45.2
|
||||||
|
|
@ -58,7 +58,7 @@ func availableServers() map[string]string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// glob payloadsDir for files that start with ollama_
|
// glob payloadsDir for files that start with ollama_
|
||||||
pattern := filepath.Join(payloadsDir, "*")
|
pattern := filepath.Join(payloadsDir, "*", "ollama_*")
|
||||||
|
|
||||||
files, err := filepath.Glob(pattern)
|
files, err := filepath.Glob(pattern)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -69,7 +69,7 @@ func availableServers() map[string]string {
|
|||||||
servers := make(map[string]string)
|
servers := make(map[string]string)
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
slog.Debug("availableServers : found", "file", file)
|
slog.Debug("availableServers : found", "file", file)
|
||||||
servers[filepath.Base(file)] = file
|
servers[filepath.Base(filepath.Dir(file))] = filepath.Dir(file)
|
||||||
}
|
}
|
||||||
|
|
||||||
return servers
|
return servers
|
||||||
|
@ -61,7 +61,12 @@ type llmServer struct {
|
|||||||
sem *semaphore.Weighted
|
sem *semaphore.Weighted
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadModel(model string) (*GGML, error) {
|
// LoadModel will load a model from disk. The model must be in the GGML format.
|
||||||
|
//
|
||||||
|
// It collects array values for arrays with a size less than or equal to
|
||||||
|
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
|
||||||
|
// the maxArraySize is negative, all arrays are collected.
|
||||||
|
func LoadModel(model string, maxArraySize int) (*GGML, error) {
|
||||||
if _, err := os.Stat(model); err != nil {
|
if _, err := os.Stat(model); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -72,17 +77,27 @@ func LoadModel(model string) (*GGML, error) {
|
|||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
ggml, _, err := DecodeGGML(f)
|
ggml, _, err := DecodeGGML(f, maxArraySize)
|
||||||
return ggml, err
|
return ggml, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLlamaServer will run a server for the given GPUs
|
// NewLlamaServer will run a server for the given GPUs
|
||||||
// The gpu list must be a single family.
|
// The gpu list must be a single family.
|
||||||
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
|
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||||
var err error
|
var err error
|
||||||
var cpuRunner string
|
var cpuRunner string
|
||||||
var estimate MemoryEstimate
|
var estimate MemoryEstimate
|
||||||
var systemMemory uint64
|
var systemTotalMemory uint64
|
||||||
|
var systemFreeMemory uint64
|
||||||
|
|
||||||
|
systemMemInfo, err := gpu.GetCPUMem()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to lookup system memory", "error", err)
|
||||||
|
} else {
|
||||||
|
systemTotalMemory = systemMemInfo.TotalMemory
|
||||||
|
systemFreeMemory = systemMemInfo.FreeMemory
|
||||||
|
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
|
||||||
|
}
|
||||||
|
|
||||||
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
|
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
|
||||||
if opts.NumGPU == 0 {
|
if opts.NumGPU == 0 {
|
||||||
@ -92,19 +107,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
cpuRunner = serverForCpu()
|
cpuRunner = serverForCpu()
|
||||||
estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
|
estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||||
} else {
|
} else {
|
||||||
if gpus[0].Library == "metal" {
|
|
||||||
memInfo, err := gpu.GetCPUMem()
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to lookup system memory", "error", err)
|
|
||||||
} else {
|
|
||||||
systemMemory = memInfo.TotalMemory
|
|
||||||
slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
|
estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case gpus[0].Library == "metal" && estimate.VRAMSize > systemMemory:
|
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
|
||||||
// disable partial offloading when model is greater than total system memory as this
|
// disable partial offloading when model is greater than total system memory as this
|
||||||
// can lead to locking up the system
|
// can lead to locking up the system
|
||||||
opts.NumGPU = 0
|
opts.NumGPU = 0
|
||||||
@ -212,7 +218,12 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Windows CUDA should not use mmap for best performance
|
// Windows CUDA should not use mmap for best performance
|
||||||
if (runtime.GOOS == "windows" && gpus[0].Library == "cuda") || opts.UseMMap == api.TriStateFalse {
|
// Linux with a model larger than free space, mmap leads to thrashing
|
||||||
|
// For CPU loads we want the memory to be allocated, not FS cache
|
||||||
|
if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == api.TriStateUndefined) ||
|
||||||
|
(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == api.TriStateUndefined) ||
|
||||||
|
(gpus[0].Library == "cpu" && opts.UseMMap == api.TriStateUndefined) ||
|
||||||
|
opts.UseMMap == api.TriStateFalse {
|
||||||
params = append(params, "--no-mmap")
|
params = append(params, "--no-mmap")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,15 +235,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
params = append(params, "--numa")
|
params = append(params, "--numa")
|
||||||
}
|
}
|
||||||
|
|
||||||
numParallel := envconfig.NumParallel
|
|
||||||
|
|
||||||
// TODO (jmorganca): multimodal models don't support parallel yet
|
|
||||||
// see https://github.com/ollama/ollama/issues/4165
|
|
||||||
if len(projectors) > 0 {
|
|
||||||
numParallel = 1
|
|
||||||
slog.Warn("multimodal models don't support parallel requests yet")
|
|
||||||
}
|
|
||||||
|
|
||||||
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
|
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
|
||||||
|
|
||||||
if estimate.TensorSplit != "" {
|
if estimate.TensorSplit != "" {
|
||||||
@ -275,8 +277,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
pathEnv = "PATH"
|
pathEnv = "PATH"
|
||||||
}
|
}
|
||||||
// prepend the server directory to LD_LIBRARY_PATH/PATH
|
// prepend the server directory to LD_LIBRARY_PATH/PATH and the parent dir for common dependencies
|
||||||
libraryPaths := []string{dir}
|
libraryPaths := []string{dir, filepath.Dir(dir)}
|
||||||
|
|
||||||
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
|
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
|
||||||
// Append our runner directory to the path
|
// Append our runner directory to the path
|
||||||
@ -409,7 +411,7 @@ func projectorMemoryRequirements(filename string) uint64 {
|
|||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
ggml, _, err := DecodeGGML(file)
|
ggml, _, err := DecodeGGML(file, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@ -559,6 +561,9 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||||||
if s.status != nil && s.status.LastErrMsg != "" {
|
if s.status != nil && s.status.LastErrMsg != "" {
|
||||||
msg = s.status.LastErrMsg
|
msg = s.status.LastErrMsg
|
||||||
}
|
}
|
||||||
|
if strings.Contains(msg, "unknown model") {
|
||||||
|
return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
|
||||||
|
}
|
||||||
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,7 @@ var errorPrefixes = []string{
|
|||||||
"CUDA error",
|
"CUDA error",
|
||||||
"cudaMalloc failed",
|
"cudaMalloc failed",
|
||||||
"\"ERR\"",
|
"\"ERR\"",
|
||||||
|
"architecture",
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StatusWriter) Write(b []byte) (int, error) {
|
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||||
|
382
openai/openai.go
382
openai/openai.go
@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
@ -42,6 +43,12 @@ type ChunkChoice struct {
|
|||||||
FinishReason *string `json:"finish_reason"`
|
FinishReason *string `json:"finish_reason"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CompleteChunkChoice struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason *string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
type Usage struct {
|
type Usage struct {
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
@ -85,6 +92,51 @@ type ChatCompletionChunk struct {
|
|||||||
Choices []ChunkChoice `json:"choices"`
|
Choices []ChunkChoice `json:"choices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
||||||
|
type CompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||||
|
MaxTokens *int `json:"max_tokens"`
|
||||||
|
PresencePenalty float32 `json:"presence_penalty"`
|
||||||
|
Seed *int `json:"seed"`
|
||||||
|
Stop any `json:"stop"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
Temperature *float32 `json:"temperature"`
|
||||||
|
TopP float32 `json:"top_p"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Completion struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
Choices []CompleteChunkChoice `json:"choices"`
|
||||||
|
Usage Usage `json:"usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionChunk struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Choices []CompleteChunkChoice `json:"choices"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListCompletion struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []Model `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
func NewError(code int, message string) ErrorResponse {
|
func NewError(code int, message string) ErrorResponse {
|
||||||
var etype string
|
var etype string
|
||||||
switch code {
|
switch code {
|
||||||
@ -145,7 +197,79 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||||
|
return Completion{
|
||||||
|
Id: id,
|
||||||
|
Object: "text_completion",
|
||||||
|
Created: r.CreatedAt.Unix(),
|
||||||
|
Model: r.Model,
|
||||||
|
SystemFingerprint: "fp_ollama",
|
||||||
|
Choices: []CompleteChunkChoice{{
|
||||||
|
Text: r.Response,
|
||||||
|
Index: 0,
|
||||||
|
FinishReason: func(reason string) *string {
|
||||||
|
if len(reason) > 0 {
|
||||||
|
return &reason
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}(r.DoneReason),
|
||||||
|
}},
|
||||||
|
Usage: Usage{
|
||||||
|
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
||||||
|
PromptTokens: r.PromptEvalCount,
|
||||||
|
CompletionTokens: r.EvalCount,
|
||||||
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||||
|
return CompletionChunk{
|
||||||
|
Id: id,
|
||||||
|
Object: "text_completion",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Model: r.Model,
|
||||||
|
SystemFingerprint: "fp_ollama",
|
||||||
|
Choices: []CompleteChunkChoice{{
|
||||||
|
Text: r.Response,
|
||||||
|
Index: 0,
|
||||||
|
FinishReason: func(reason string) *string {
|
||||||
|
if len(reason) > 0 {
|
||||||
|
return &reason
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}(r.DoneReason),
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toListCompletion(r api.ListResponse) ListCompletion {
|
||||||
|
var data []Model
|
||||||
|
for _, m := range r.Models {
|
||||||
|
data = append(data, Model{
|
||||||
|
Id: m.Name,
|
||||||
|
Object: "model",
|
||||||
|
Created: m.ModifiedAt.Unix(),
|
||||||
|
OwnedBy: model.ParseName(m.Name).Namespace,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return ListCompletion{
|
||||||
|
Object: "list",
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toModel(r api.ShowResponse, m string) Model {
|
||||||
|
return Model{
|
||||||
|
Id: m,
|
||||||
|
Object: "model",
|
||||||
|
Created: r.ModifiedAt.Unix(),
|
||||||
|
OwnedBy: model.ParseName(m).Namespace,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
for _, msg := range r.Messages {
|
for _, msg := range r.Messages {
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
||||||
@ -156,7 +280,7 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
|||||||
switch stop := r.Stop.(type) {
|
switch stop := r.Stop.(type) {
|
||||||
case string:
|
case string:
|
||||||
options["stop"] = []string{stop}
|
options["stop"] = []string{stop}
|
||||||
case []interface{}:
|
case []any:
|
||||||
var stops []string
|
var stops []string
|
||||||
for _, s := range stop {
|
for _, s := range stop {
|
||||||
if str, ok := s.(string); ok {
|
if str, ok := s.(string); ok {
|
||||||
@ -208,13 +332,78 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type writer struct {
|
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||||
stream bool
|
options := make(map[string]any)
|
||||||
id string
|
|
||||||
|
switch stop := r.Stop.(type) {
|
||||||
|
case string:
|
||||||
|
options["stop"] = []string{stop}
|
||||||
|
case []string:
|
||||||
|
options["stop"] = stop
|
||||||
|
default:
|
||||||
|
if r.Stop != nil {
|
||||||
|
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.MaxTokens != nil {
|
||||||
|
options["num_predict"] = *r.MaxTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Temperature != nil {
|
||||||
|
options["temperature"] = *r.Temperature * 2.0
|
||||||
|
} else {
|
||||||
|
options["temperature"] = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Seed != nil {
|
||||||
|
options["seed"] = *r.Seed
|
||||||
|
}
|
||||||
|
|
||||||
|
options["frequency_penalty"] = r.FrequencyPenalty * 2.0
|
||||||
|
|
||||||
|
options["presence_penalty"] = r.PresencePenalty * 2.0
|
||||||
|
|
||||||
|
if r.TopP != 0.0 {
|
||||||
|
options["top_p"] = r.TopP
|
||||||
|
} else {
|
||||||
|
options["top_p"] = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.GenerateRequest{
|
||||||
|
Model: r.Model,
|
||||||
|
Prompt: r.Prompt,
|
||||||
|
Options: options,
|
||||||
|
Stream: &r.Stream,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaseWriter struct {
|
||||||
gin.ResponseWriter
|
gin.ResponseWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *writer) writeError(code int, data []byte) (int, error) {
|
type ChatWriter struct {
|
||||||
|
stream bool
|
||||||
|
id string
|
||||||
|
BaseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompleteWriter struct {
|
||||||
|
stream bool
|
||||||
|
id string
|
||||||
|
BaseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListWriter struct {
|
||||||
|
BaseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type RetrieveWriter struct {
|
||||||
|
BaseWriter
|
||||||
|
model string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
||||||
var serr api.StatusError
|
var serr api.StatusError
|
||||||
err := json.Unmarshal(data, &serr)
|
err := json.Unmarshal(data, &serr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -230,7 +419,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) {
|
|||||||
return len(data), nil
|
return len(data), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *writer) writeResponse(data []byte) (int, error) {
|
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||||
var chatResponse api.ChatResponse
|
var chatResponse api.ChatResponse
|
||||||
err := json.Unmarshal(data, &chatResponse)
|
err := json.Unmarshal(data, &chatResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -270,7 +459,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
|
|||||||
return len(data), nil
|
return len(data), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *writer) Write(data []byte) (int, error) {
|
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||||
code := w.ResponseWriter.Status()
|
code := w.ResponseWriter.Status()
|
||||||
if code != http.StatusOK {
|
if code != http.StatusOK {
|
||||||
return w.writeError(code, data)
|
return w.writeError(code, data)
|
||||||
@ -279,7 +468,176 @@ func (w *writer) Write(data []byte) (int, error) {
|
|||||||
return w.writeResponse(data)
|
return w.writeResponse(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Middleware() gin.HandlerFunc {
|
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var generateResponse api.GenerateResponse
|
||||||
|
err := json.Unmarshal(data, &generateResponse)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// completion chunk
|
||||||
|
if w.stream {
|
||||||
|
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if generateResponse.Done {
|
||||||
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// completion
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(code, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var listResponse api.ListResponse
|
||||||
|
err := json.Unmarshal(data, &listResponse)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(code, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var showResponse api.ShowResponse
|
||||||
|
err := json.Unmarshal(data, &showResponse)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// retrieve completion
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(code, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
w := &ListWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RetrieveMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
// response writer
|
||||||
|
w := &RetrieveWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
model: c.Param("model"),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CompletionsMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var req CompletionRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
genReq, err := fromCompleteRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
w := &CompleteWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
stream: req.Stream,
|
||||||
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ChatMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
var req ChatCompletionRequest
|
var req ChatCompletionRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
@ -294,15 +652,15 @@ func Middleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
|
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(&b)
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
w := &writer{
|
w := &ChatWriter{
|
||||||
ResponseWriter: c.Writer,
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
stream: req.Stream,
|
stream: req.Stream,
|
||||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||||
}
|
}
|
||||||
|
298
openai/openai_test.go
Normal file
298
openai/openai_test.go
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Method string
|
||||||
|
Path string
|
||||||
|
TestPath string
|
||||||
|
Handler func() gin.HandlerFunc
|
||||||
|
Endpoint func(c *gin.Context)
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
Name: "chat handler",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/chat",
|
||||||
|
TestPath: "/api/chat",
|
||||||
|
Handler: ChatMiddleware,
|
||||||
|
Endpoint: func(c *gin.Context) {
|
||||||
|
var chatReq api.ChatRequest
|
||||||
|
if err := c.ShouldBindJSON(&chatReq); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userMessage := chatReq.Messages[0].Content
|
||||||
|
var assistantMessage string
|
||||||
|
|
||||||
|
switch userMessage {
|
||||||
|
case "Hello":
|
||||||
|
assistantMessage = "Hello!"
|
||||||
|
default:
|
||||||
|
assistantMessage = "I'm not sure how to respond to that."
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, api.ChatResponse{
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: assistantMessage,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
|
||||||
|
var chatResp ChatCompletion
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatResp.Object != "chat.completion" {
|
||||||
|
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatResp.Choices[0].Message.Content != "Hello!" {
|
||||||
|
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "completions handler",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/generate",
|
||||||
|
TestPath: "/api/generate",
|
||||||
|
Handler: CompletionsMiddleware,
|
||||||
|
Endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
|
Response: "Hello!",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
var completionResp Completion
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if completionResp.Object != "text_completion" {
|
||||||
|
t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
||||||
|
}
|
||||||
|
|
||||||
|
if completionResp.Choices[0].Text != "Hello!" {
|
||||||
|
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "completions handler with params",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/generate",
|
||||||
|
TestPath: "/api/generate",
|
||||||
|
Handler: CompletionsMiddleware,
|
||||||
|
Endpoint: func(c *gin.Context) {
|
||||||
|
var generateReq api.GenerateRequest
|
||||||
|
if err := c.ShouldBindJSON(&generateReq); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
temperature := generateReq.Options["temperature"].(float64)
|
||||||
|
var assistantMessage string
|
||||||
|
|
||||||
|
switch temperature {
|
||||||
|
case 1.6:
|
||||||
|
assistantMessage = "Received temperature of 1.6"
|
||||||
|
default:
|
||||||
|
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
|
Response: assistantMessage,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
temp := float32(0.8)
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Temperature: &temp,
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
var completionResp Completion
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if completionResp.Object != "text_completion" {
|
||||||
|
t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
||||||
|
}
|
||||||
|
|
||||||
|
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
|
||||||
|
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "completions handler with error",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/generate",
|
||||||
|
TestPath: "/api/generate",
|
||||||
|
Handler: CompletionsMiddleware,
|
||||||
|
Endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
||||||
|
},
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
|
||||||
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "list handler",
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/api/tags",
|
||||||
|
TestPath: "/api/tags",
|
||||||
|
Handler: ListMiddleware,
|
||||||
|
Endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, api.ListResponse{
|
||||||
|
Models: []api.ListModelResponse{
|
||||||
|
{
|
||||||
|
Name: "Test Model",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
|
||||||
|
var listResp ListCompletion
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if listResp.Object != "list" {
|
||||||
|
t.Fatalf("expected list, got %s", listResp.Object)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(listResp.Data) != 1 {
|
||||||
|
t.Fatalf("expected 1, got %d", len(listResp.Data))
|
||||||
|
}
|
||||||
|
|
||||||
|
if listResp.Data[0].Id != "Test Model" {
|
||||||
|
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "retrieve model",
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/api/show/:model",
|
||||||
|
TestPath: "/api/show/test-model",
|
||||||
|
Handler: RetrieveMiddleware,
|
||||||
|
Endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, api.ShowResponse{
|
||||||
|
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
|
||||||
|
})
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
|
var retrieveResp Model
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if retrieveResp.Object != "model" {
|
||||||
|
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
|
||||||
|
}
|
||||||
|
|
||||||
|
if retrieveResp.Id != "test-model" {
|
||||||
|
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
router = gin.New()
|
||||||
|
router.Use(tc.Handler())
|
||||||
|
router.Handle(tc.Method, tc.Path, tc.Endpoint)
|
||||||
|
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
|
||||||
|
|
||||||
|
if tc.Setup != nil {
|
||||||
|
tc.Setup(t, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
tc.Expected(t, resp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -124,7 +124,7 @@ func ParseFile(r io.Reader) (*File, error) {
|
|||||||
case stateComment, stateNil:
|
case stateComment, stateNil:
|
||||||
// pass
|
// pass
|
||||||
case stateValue:
|
case stateValue:
|
||||||
s, ok := unquote(b.String())
|
s, ok := unquote(strings.TrimSpace(b.String()))
|
||||||
if !ok || isSpace(r) {
|
if !ok || isSpace(r) {
|
||||||
if _, err := b.WriteRune(r); err != nil {
|
if _, err := b.WriteRune(r); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -158,7 +158,7 @@ func ParseFile(r io.Reader) (*File, error) {
|
|||||||
case stateComment, stateNil:
|
case stateComment, stateNil:
|
||||||
// pass; nothing to flush
|
// pass; nothing to flush
|
||||||
case stateValue:
|
case stateValue:
|
||||||
s, ok := unquote(b.String())
|
s, ok := unquote(strings.TrimSpace(b.String()))
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,13 @@ ADAPTER adapter1
|
|||||||
LICENSE MIT
|
LICENSE MIT
|
||||||
PARAMETER param1 value1
|
PARAMETER param1 value1
|
||||||
PARAMETER param2 value2
|
PARAMETER param2 value2
|
||||||
TEMPLATE template1
|
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .Response }}<|eot_id|>"""
|
||||||
`
|
`
|
||||||
|
|
||||||
reader := strings.NewReader(input)
|
reader := strings.NewReader(input)
|
||||||
@ -36,7 +42,40 @@ TEMPLATE template1
|
|||||||
{Name: "license", Args: "MIT"},
|
{Name: "license", Args: "MIT"},
|
||||||
{Name: "param1", Args: "value1"},
|
{Name: "param1", Args: "value1"},
|
||||||
{Name: "param2", Args: "value2"},
|
{Name: "param2", Args: "value2"},
|
||||||
{Name: "template", Args: "template1"},
|
{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expectedCommands, modelfile.Commands)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFileTrimSpace(t *testing.T) {
|
||||||
|
input := `
|
||||||
|
FROM " model 1"
|
||||||
|
ADAPTER adapter3
|
||||||
|
LICENSE "MIT "
|
||||||
|
PARAMETER param1 value1
|
||||||
|
PARAMETER param2 value2
|
||||||
|
TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .Response }}<|eot_id|> """
|
||||||
|
`
|
||||||
|
|
||||||
|
reader := strings.NewReader(input)
|
||||||
|
|
||||||
|
modelfile, err := ParseFile(reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectedCommands := []Command{
|
||||||
|
{Name: "model", Args: " model 1"},
|
||||||
|
{Name: "adapter", Args: "adapter3"},
|
||||||
|
{Name: "license", Args: "MIT "},
|
||||||
|
{Name: "param1", Args: "value1"},
|
||||||
|
{Name: "param2", Args: "value2"},
|
||||||
|
{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expectedCommands, modelfile.Commands)
|
assert.Equal(t, expectedCommands, modelfile.Commands)
|
||||||
@ -48,6 +87,26 @@ func TestParseFileFrom(t *testing.T) {
|
|||||||
expected []Command
|
expected []Command
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
|
{
|
||||||
|
"FROM \"FOO BAR \"",
|
||||||
|
[]Command{{Name: "model", Args: "FOO BAR "}},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"FROM \"FOO BAR\"\nPARAMETER param1 value1",
|
||||||
|
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"FROM FOOO BAR ",
|
||||||
|
[]Command{{Name: "model", Args: "FOOO BAR"}},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"FROM /what/is/the path ",
|
||||||
|
[]Command{{Name: "model", Args: "/what/is/the path"}},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"FROM foo",
|
"FROM foo",
|
||||||
[]Command{{Name: "model", Args: "foo"}},
|
[]Command{{Name: "model", Args: "foo"}},
|
||||||
@ -86,6 +145,11 @@ func TestParseFileFrom(t *testing.T) {
|
|||||||
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
|
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"PARAMETER what the \nFROM lemons make lemonade ",
|
||||||
|
[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
@ -399,7 +463,7 @@ func TestParseFileParameters(t *testing.T) {
|
|||||||
"mirostat_eta 1.0": {"mirostat_eta", "1.0"},
|
"mirostat_eta 1.0": {"mirostat_eta", "1.0"},
|
||||||
"penalize_newline true": {"penalize_newline", "true"},
|
"penalize_newline true": {"penalize_newline", "true"},
|
||||||
"stop ### User:": {"stop", "### User:"},
|
"stop ### User:": {"stop", "### User:"},
|
||||||
"stop ### User: ": {"stop", "### User: "},
|
"stop ### User: ": {"stop", "### User:"},
|
||||||
"stop \"### User:\"": {"stop", "### User:"},
|
"stop \"### User:\"": {"stop", "### User:"},
|
||||||
"stop \"### User: \"": {"stop", "### User: "},
|
"stop \"### User: \"": {"stop", "### User: "},
|
||||||
"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
|
"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
|
||||||
|
@ -103,19 +103,19 @@ function buildApp() {
|
|||||||
function gatherDependencies() {
|
function gatherDependencies() {
|
||||||
write-host "Gathering runtime dependencies"
|
write-host "Gathering runtime dependencies"
|
||||||
cd "${script:SRC_DIR}"
|
cd "${script:SRC_DIR}"
|
||||||
md "${script:DEPS_DIR}" -ea 0 > $null
|
md "${script:DEPS_DIR}\ollama_runners" -ea 0 > $null
|
||||||
|
|
||||||
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
|
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
|
||||||
# currently works for Win11 + MSVC 2019 + Cuda V11
|
# currently works for Win11 + MSVC 2019 + Cuda V11
|
||||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\"
|
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\"
|
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\"
|
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||||
|
|
||||||
|
|
||||||
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
|
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
|
||||||
if ("${env:KEY_CONTAINER}") {
|
if ("${env:KEY_CONTAINER}") {
|
||||||
write-host "about to sign"
|
write-host "about to sign"
|
||||||
foreach ($file in (get-childitem "${script:DEPS_DIR}/cu*.dll") + @("${script:SRC_DIR}\dist\ollama_welcome.ps1")){
|
foreach ($file in (get-childitem "${script:DEPS_DIR}\cuda\cu*.dll") + @("${script:SRC_DIR}\dist\ollama_welcome.ps1")){
|
||||||
write-host "signing $file"
|
write-host "signing $file"
|
||||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||||
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} $file
|
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} $file
|
||||||
|
@ -279,7 +279,7 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\
|
|||||||
case $OS_NAME in
|
case $OS_NAME in
|
||||||
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
|
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
|
||||||
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
|
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
|
||||||
fedora) [ $OS_VERSION -lt '37' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '37';;
|
fedora) [ $OS_VERSION -lt '39' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '39';;
|
||||||
amzn) install_cuda_driver_yum 'fedora' '37' ;;
|
amzn) install_cuda_driver_yum 'fedora' '37' ;;
|
||||||
debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
|
debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
|
||||||
ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;
|
ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;
|
||||||
|
@ -6,10 +6,21 @@ set -ex
|
|||||||
MACHINE=$(uname -m)
|
MACHINE=$(uname -m)
|
||||||
|
|
||||||
if grep -i "centos" /etc/system-release >/dev/null; then
|
if grep -i "centos" /etc/system-release >/dev/null; then
|
||||||
|
# As of 7/1/2024 mirrorlist.centos.org has been taken offline, so adjust accordingly
|
||||||
|
sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
|
||||||
|
sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
|
||||||
|
sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
|
||||||
|
|
||||||
# Centos 7 derivatives have too old of a git version to run our generate script
|
# Centos 7 derivatives have too old of a git version to run our generate script
|
||||||
# uninstall and ignore failures
|
# uninstall and ignore failures
|
||||||
yum remove -y git
|
yum remove -y git
|
||||||
yum -y install epel-release centos-release-scl
|
yum -y install epel-release centos-release-scl
|
||||||
|
|
||||||
|
# The release packages reinstate the mirrors, undo that again
|
||||||
|
sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
|
||||||
|
sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
|
||||||
|
sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
|
||||||
|
|
||||||
yum -y install dnf
|
yum -y install dnf
|
||||||
if [ "${MACHINE}" = "x86_64" ]; then
|
if [ "${MACHINE}" = "x86_64" ]; then
|
||||||
yum -y install https://repo.ius.io/ius-release-el7.rpm
|
yum -y install https://repo.ius.io/ius-release-el7.rpm
|
||||||
|
@ -28,11 +28,16 @@ import (
|
|||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Capability string
|
||||||
|
|
||||||
|
const CapabilityCompletion = Capability("completion")
|
||||||
|
|
||||||
type registryOptions struct {
|
type registryOptions struct {
|
||||||
Insecure bool
|
Insecure bool
|
||||||
Username string
|
Username string
|
||||||
@ -48,16 +53,43 @@ type Model struct {
|
|||||||
ParentModel string
|
ParentModel string
|
||||||
AdapterPaths []string
|
AdapterPaths []string
|
||||||
ProjectorPaths []string
|
ProjectorPaths []string
|
||||||
Template string
|
|
||||||
System string
|
System string
|
||||||
License []string
|
License []string
|
||||||
Digest string
|
Digest string
|
||||||
Options map[string]interface{}
|
Options map[string]interface{}
|
||||||
Messages []Message
|
Messages []Message
|
||||||
|
|
||||||
|
Template *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) IsEmbedding() bool {
|
func (m *Model) Has(caps ...Capability) bool {
|
||||||
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
|
for _, cap := range caps {
|
||||||
|
switch cap {
|
||||||
|
case CapabilityCompletion:
|
||||||
|
f, err := os.Open(m.ModelPath)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("couldn't open model file", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
|
||||||
|
ggml, _, err := llm.DecodeGGML(f, 0)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("couldn't decode ggml", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
slog.Error("unknown capability", "capability", cap)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) String() string {
|
func (m *Model) String() string {
|
||||||
@ -82,10 +114,10 @@ func (m *Model) String() string {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Template != "" {
|
if m.Template != nil {
|
||||||
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
||||||
Name: "template",
|
Name: "template",
|
||||||
Args: m.Template,
|
Args: m.Template.String(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,13 +167,6 @@ type Message struct {
|
|||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ManifestV2 struct {
|
|
||||||
SchemaVersion int `json:"schemaVersion"`
|
|
||||||
MediaType string `json:"mediaType"`
|
|
||||||
Config *Layer `json:"config"`
|
|
||||||
Layers []*Layer `json:"layers"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConfigV2 struct {
|
type ConfigV2 struct {
|
||||||
ModelFormat string `json:"model_format"`
|
ModelFormat string `json:"model_format"`
|
||||||
ModelFamily string `json:"model_family"`
|
ModelFamily string `json:"model_family"`
|
||||||
@ -160,7 +185,7 @@ type RootFS struct {
|
|||||||
DiffIDs []string `json:"diff_ids"`
|
DiffIDs []string `json:"diff_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
|
func GetManifest(mp ModelPath) (*Manifest, string, error) {
|
||||||
fp, err := mp.GetManifestPath()
|
fp, err := mp.GetManifestPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@ -170,7 +195,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
var manifest *ManifestV2
|
var manifest *Manifest
|
||||||
|
|
||||||
bts, err := os.ReadFile(fp)
|
bts, err := os.ReadFile(fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -198,8 +223,7 @@ func GetModel(name string) (*Model, error) {
|
|||||||
Name: mp.GetFullTagname(),
|
Name: mp.GetFullTagname(),
|
||||||
ShortName: mp.GetShortTagname(),
|
ShortName: mp.GetShortTagname(),
|
||||||
Digest: digest,
|
Digest: digest,
|
||||||
Template: "{{ .Prompt }}",
|
Template: template.DefaultTemplate,
|
||||||
License: []string{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
filename, err := GetBlobsPath(manifest.Config.Digest)
|
||||||
@ -235,13 +259,17 @@ func GetModel(name string) (*Model, error) {
|
|||||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
model.AdapterPaths = append(model.AdapterPaths, filename)
|
||||||
case "application/vnd.ollama.image.projector":
|
case "application/vnd.ollama.image.projector":
|
||||||
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
||||||
case "application/vnd.ollama.image.template":
|
case "application/vnd.ollama.image.prompt",
|
||||||
|
"application/vnd.ollama.image.template":
|
||||||
bts, err := os.ReadFile(filename)
|
bts, err := os.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model.Template = string(bts)
|
model.Template, err = template.Parse(string(bts))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
case "application/vnd.ollama.image.system":
|
case "application/vnd.ollama.image.system":
|
||||||
bts, err := os.ReadFile(filename)
|
bts, err := os.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -249,13 +277,6 @@ func GetModel(name string) (*Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
model.System = string(bts)
|
model.System = string(bts)
|
||||||
case "application/vnd.ollama.image.prompt":
|
|
||||||
bts, err := os.ReadFile(filename)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
model.Template = string(bts)
|
|
||||||
case "application/vnd.ollama.image.params":
|
case "application/vnd.ollama.image.params":
|
||||||
params, err := os.Open(filename)
|
params, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -414,17 +435,22 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
layers, err := parseFromFile(ctx, temp, "", fn)
|
layer, err := NewLayer(temp, baseLayer.MediaType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(layers) != 1 {
|
if _, err := temp.Seek(0, io.SeekStart); err != nil {
|
||||||
return errors.New("quantization failed")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
baseLayer.Layer = layers[0].Layer
|
ggml, _, err := llm.DecodeGGML(temp, 0)
|
||||||
baseLayer.GGML = layers[0].GGML
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
baseLayer.Layer = layer
|
||||||
|
baseLayer.GGML = ggml
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -817,7 +843,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
mp := ParseModelPath(name)
|
mp := ParseModelPath(name)
|
||||||
|
|
||||||
var manifest *ManifestV2
|
var manifest *Manifest
|
||||||
var err error
|
var err error
|
||||||
var noprune string
|
var noprune string
|
||||||
|
|
||||||
@ -924,7 +950,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
|
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
|
||||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||||
|
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
@ -935,7 +961,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var m *ManifestV2
|
var m *Manifest
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Manifest struct {
|
type Manifest struct {
|
||||||
ManifestV2
|
SchemaVersion int `json:"schemaVersion"`
|
||||||
|
MediaType string `json:"mediaType"`
|
||||||
|
Config *Layer `json:"config"`
|
||||||
|
Layers []*Layer `json:"layers"`
|
||||||
|
|
||||||
filepath string
|
filepath string
|
||||||
fi os.FileInfo
|
fi os.FileInfo
|
||||||
@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||||||
|
|
||||||
p := filepath.Join(manifests, n.Filepath())
|
p := filepath.Join(manifests, n.Filepath())
|
||||||
|
|
||||||
var m ManifestV2
|
var m Manifest
|
||||||
f, err := os.Open(p)
|
f, err := os.Open(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Manifest{
|
m.filepath = p
|
||||||
ManifestV2: m,
|
m.fi = fi
|
||||||
filepath: p,
|
m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
||||||
fi: fi,
|
|
||||||
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
|
return &m, nil
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
||||||
@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
|||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
m := ManifestV2{
|
m := Manifest{
|
||||||
SchemaVersion: 2,
|
SchemaVersion: 2,
|
||||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||||
Config: config,
|
Config: config,
|
||||||
|
@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) {
|
|||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
|
if err := json.NewEncoder(f).Encode(Manifest{}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/convert"
|
"github.com/ollama/ollama/convert"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/templates"
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
|||||||
}
|
}
|
||||||
defer blob.Close()
|
defer blob.Close()
|
||||||
|
|
||||||
ggml, _, err := llm.DecodeGGML(blob)
|
ggml, _, err := llm.DecodeGGML(blob, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -77,62 +77,79 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
|||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {
|
||||||
stat, err := file.Stat()
|
stat, err := file.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := zip.NewReader(file, stat.Size())
|
r, err := zip.NewReader(file, stat.Size())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(tempdir)
|
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "unpacking model metadata"})
|
fn(api.ProgressResponse{Status: "unpacking model metadata"})
|
||||||
for _, f := range r.File {
|
for _, f := range r.File {
|
||||||
|
if !filepath.IsLocal(f.Name) {
|
||||||
|
return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
n := filepath.Join(p, f.Name)
|
||||||
|
if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(mxyng): this should not write out all files to disk
|
// TODO(mxyng): this should not write out all files to disk
|
||||||
outfile, err := os.Create(filepath.Join(tempdir, f.Name))
|
outfile, err := os.Create(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
defer outfile.Close()
|
defer outfile.Close()
|
||||||
|
|
||||||
infile, err := f.Open()
|
infile, err := f.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
defer infile.Close()
|
defer infile.Close()
|
||||||
|
|
||||||
if _, err = io.Copy(outfile, infile); err != nil {
|
if _, err = io.Copy(outfile, infile); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := outfile.Close(); err != nil {
|
if err := outfile.Close(); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := infile.Close(); err != nil {
|
if err := infile.Close(); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mf, err := convert.GetModelFormat(tempdir)
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||||
|
tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
if err := extractFromZipFile(tempDir, file, fn); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mf, err := convert.GetModelFormat(tempDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params, err := mf.GetParams(tempdir)
|
params, err := mf.GetParams(tempDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
mArch, err := mf.GetModelArch("", tempdir, params)
|
mArch, err := mf.GetModelArch("", tempDir, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -150,7 +167,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
|
|||||||
|
|
||||||
// TODO(mxyng): this should write directly into a layer
|
// TODO(mxyng): this should write directly into a layer
|
||||||
// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
|
// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
|
||||||
temp, err := os.CreateTemp(tempdir, "fp16")
|
temp, err := os.CreateTemp(tempDir, "fp16")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -176,7 +193,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
|
|||||||
}
|
}
|
||||||
defer bin.Close()
|
defer bin.Close()
|
||||||
|
|
||||||
ggml, _, err := llm.DecodeGGML(bin)
|
ggml, _, err := llm.DecodeGGML(bin, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -210,7 +227,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
|
|||||||
|
|
||||||
var offset int64
|
var offset int64
|
||||||
for offset < stat.Size() {
|
for offset < stat.Size() {
|
||||||
ggml, n, err := llm.DecodeGGML(file)
|
ggml, n, err := llm.DecodeGGML(file, 0)
|
||||||
if errors.Is(err, io.EOF) {
|
if errors.Is(err, io.EOF) {
|
||||||
break
|
break
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@ -239,7 +256,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
|
|||||||
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
||||||
if t, err := templates.NamedTemplate(s); err != nil {
|
if t, err := template.Named(s); err != nil {
|
||||||
slog.Debug("template detection", "error", err)
|
slog.Debug("template detection", "error", err)
|
||||||
} else {
|
} else {
|
||||||
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||||
|
112
server/model_test.go
Normal file
112
server/model_test.go
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/zip"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createZipFile(t *testing.T, name string) *os.File {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
f, err := os.CreateTemp(t.TempDir(), "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
zf := zip.NewWriter(f)
|
||||||
|
defer zf.Close()
|
||||||
|
|
||||||
|
zh, err := zf.CreateHeader(&zip.FileHeader{Name: name})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractFromZipFile(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
expect []string
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "good",
|
||||||
|
expect: []string{"good"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
|
||||||
|
expect: []string{filepath.Join("to", "good")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
|
||||||
|
expect: []string{"good"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
|
||||||
|
expect: []string{"good"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
|
||||||
|
err: zip.ErrInsecurePath,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
|
||||||
|
err: zip.ErrInsecurePath,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
f := createZipFile(t, tt.name)
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var matches []string
|
||||||
|
if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fi.IsDir() {
|
||||||
|
matches = append(matches, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual []string
|
||||||
|
for _, match := range matches {
|
||||||
|
rel, err := filepath.Rel(tempDir, match)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = append(actual, rel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Equal(actual, tt.expect) {
|
||||||
|
t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -4,10 +4,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
|
||||||
"text/template/parse"
|
"text/template/parse"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
// isResponseNode checks if the node contains .Response
|
// isResponseNode checks if the node contains .Response
|
||||||
@ -53,13 +54,8 @@ func formatTemplateForResponse(tmpl *template.Template, generate bool) {
|
|||||||
|
|
||||||
// Prompt renders a prompt from a template. If generate is set to true,
|
// Prompt renders a prompt from a template. If generate is set to true,
|
||||||
// the response and parts of the template following it are not rendered
|
// the response and parts of the template following it are not rendered
|
||||||
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
|
func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
|
||||||
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
|
formatTemplateForResponse(tmpl, generate)
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
formatTemplateForResponse(parsed, generate)
|
|
||||||
|
|
||||||
vars := map[string]any{
|
vars := map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
@ -68,14 +64,14 @@ func Prompt(tmpl, system, prompt, response string, generate bool) (string, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
if err := parsed.Execute(&sb, vars); err != nil {
|
if err := tmpl.Execute(&sb, vars); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
|
func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
|
||||||
rendered, err := Prompt(tmpl, system, prompt, response, false)
|
rendered, err := Prompt(tmpl, system, prompt, response, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@ -91,7 +87,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
||||||
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
||||||
type prompt struct {
|
type prompt struct {
|
||||||
System string
|
System string
|
||||||
Prompt string
|
Prompt string
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPrompt(t *testing.T) {
|
func TestPrompt(t *testing.T) {
|
||||||
@ -61,7 +62,12 @@ func TestPrompt(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
|
tmpl, err := template.Parse(tc.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error = %v", err)
|
t.Errorf("error = %v", err)
|
||||||
}
|
}
|
||||||
@ -192,7 +198,12 @@ func TestChatPrompt(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
|
tmpl, err := template.Parse(tc.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error = %v", err)
|
t.Errorf("error = %v", err)
|
||||||
}
|
}
|
||||||
|
@ -31,6 +31,7 @@ import (
|
|||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@ -121,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.IsEmbedding() {
|
if !model.Has(CapabilityCompletion) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tmpl, err := template.Parse(req.Template)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
var prompt string
|
var prompt string
|
||||||
@ -169,7 +176,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
prompt = req.Prompt
|
prompt = req.Prompt
|
||||||
case req.Prompt != "":
|
case req.Prompt != "":
|
||||||
if req.Template == "" {
|
if req.Template == "" {
|
||||||
req.Template = model.Template
|
tmpl = model.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.System == "" {
|
if req.System == "" {
|
||||||
@ -187,7 +194,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
|
|
||||||
sb.WriteString(req.Prompt)
|
sb.WriteString(req.Prompt)
|
||||||
|
|
||||||
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
|
p, err := Prompt(tmpl, req.System, sb.String(), "", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@ -242,7 +249,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
|
p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@ -832,7 +839,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if req.Template != "" {
|
if req.Template != "" {
|
||||||
m.Template = req.Template
|
m.Template, err = template.Parse(req.Template)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs := make([]api.Message, 0)
|
msgs := make([]api.Message, 0)
|
||||||
@ -853,7 +863,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
resp := &api.ShowResponse{
|
resp := &api.ShowResponse{
|
||||||
License: strings.Join(m.License, "\n"),
|
License: strings.Join(m.License, "\n"),
|
||||||
System: m.System,
|
System: m.System,
|
||||||
Template: m.Template,
|
Template: m.Template.String(),
|
||||||
Details: modelDetails,
|
Details: modelDetails,
|
||||||
Messages: msgs,
|
Messages: msgs,
|
||||||
ModifiedAt: manifest.fi.ModTime(),
|
ModifiedAt: manifest.fi.ModTime(),
|
||||||
@ -886,9 +896,48 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
fmt.Fprint(&sb, m.String())
|
fmt.Fprint(&sb, m.String())
|
||||||
resp.Modelfile = sb.String()
|
resp.Modelfile = sb.String()
|
||||||
|
|
||||||
|
kvData, err := getKVData(m.ModelPath, req.Verbose)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
delete(kvData, "general.name")
|
||||||
|
delete(kvData, "tokenizer.chat_template")
|
||||||
|
resp.ModelInfo = kvData
|
||||||
|
|
||||||
|
if len(m.ProjectorPaths) > 0 {
|
||||||
|
projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp.ProjectorInfo = projectorData
|
||||||
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getKVData(digest string, verbose bool) (llm.KV, error) {
|
||||||
|
maxArraySize := 0
|
||||||
|
if verbose {
|
||||||
|
maxArraySize = -1
|
||||||
|
}
|
||||||
|
kvData, err := llm.LoadModel(digest, maxArraySize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
kv := kvData.KV()
|
||||||
|
|
||||||
|
if !verbose {
|
||||||
|
for k := range kv {
|
||||||
|
if t, ok := kv[k].([]any); len(t) > 5 && ok {
|
||||||
|
kv[k] = []any{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) ListModelsHandler(c *gin.Context) {
|
func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||||
ms, err := Manifests()
|
ms, err := Manifests()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1153,7 +1202,10 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|||||||
r.GET("/api/ps", s.ProcessHandler)
|
r.GET("/api/ps", s.ProcessHandler)
|
||||||
|
|
||||||
// Compatibility endpoints
|
// Compatibility endpoints
|
||||||
r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
|
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
||||||
|
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
|
||||||
|
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
|
||||||
|
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
|
||||||
|
|
||||||
for _, method := range []string{http.MethodGet, http.MethodHead} {
|
for _, method := range []string{http.MethodGet, http.MethodHead} {
|
||||||
r.Handle(method, "/", func(c *gin.Context) {
|
r.Handle(method, "/", func(c *gin.Context) {
|
||||||
@ -1219,11 +1271,20 @@ func Serve(ln net.Listener) error {
|
|||||||
schedCtx, schedDone := context.WithCancel(ctx)
|
schedCtx, schedDone := context.WithCancel(ctx)
|
||||||
sched := InitScheduler(schedCtx)
|
sched := InitScheduler(schedCtx)
|
||||||
s := &Server{addr: ln.Addr(), sched: sched}
|
s := &Server{addr: ln.Addr(), sched: sched}
|
||||||
r := s.GenerateRoutes()
|
|
||||||
|
http.Handle("/", s.GenerateRoutes())
|
||||||
|
|
||||||
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
|
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
|
||||||
srvr := &http.Server{
|
srvr := &http.Server{
|
||||||
Handler: r,
|
// Use http.DefaultServeMux so we get net/http/pprof for
|
||||||
|
// free.
|
||||||
|
//
|
||||||
|
// TODO(bmizerany): Decide if we want to make this
|
||||||
|
// configurable so it is not exposed by default, or allow
|
||||||
|
// users to bind it to a different port. This was a quick
|
||||||
|
// and easy way to get pprof, but it may not be the best
|
||||||
|
// way.
|
||||||
|
Handler: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen for a ctrl+c and stop any loaded llm
|
// listen for a ctrl+c and stop any loaded llm
|
||||||
@ -1342,11 +1403,16 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
|||||||
models = append(models, mr)
|
models = append(models, mr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
|
||||||
|
// longest duration remaining listed first
|
||||||
|
return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
|
||||||
|
})
|
||||||
|
|
||||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||||
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
|
func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
|
||||||
encode := func(s string) ([]int, error) {
|
encode := func(s string) ([]int, error) {
|
||||||
return runner.llama.Tokenize(ctx, s)
|
return runner.llama.Tokenize(ctx, s)
|
||||||
}
|
}
|
||||||
@ -1394,8 +1460,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.IsEmbedding() {
|
if !model.Has(CapabilityCompletion) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,6 +20,8 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
"github.com/ollama/ollama/openai"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@ -105,6 +107,24 @@ func Test_Routes(t *testing.T) {
|
|||||||
assert.Empty(t, len(modelList.Models))
|
assert.Empty(t, len(modelList.Models))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "openai empty list",
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/v1/models",
|
||||||
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
assert.Equal(t, "application/json", contentType)
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var modelList openai.ListCompletion
|
||||||
|
err = json.Unmarshal(body, &modelList)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "list", modelList.Object)
|
||||||
|
assert.Empty(t, modelList.Data)
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "Tags Handler (yes tags)",
|
Name: "Tags Handler (yes tags)",
|
||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
@ -128,6 +148,25 @@ func Test_Routes(t *testing.T) {
|
|||||||
assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
|
assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "openai list models with tags",
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/v1/models",
|
||||||
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
assert.Equal(t, "application/json", contentType)
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var modelList openai.ListCompletion
|
||||||
|
err = json.Unmarshal(body, &modelList)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, modelList.Data, 1)
|
||||||
|
assert.Equal(t, "test-model:latest", modelList.Data[0].Id)
|
||||||
|
assert.Equal(t, "library", modelList.Data[0].OwnedBy)
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "Create Model Handler",
|
Name: "Create Model Handler",
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
@ -213,6 +252,25 @@ func Test_Routes(t *testing.T) {
|
|||||||
"top_p 0.9",
|
"top_p 0.9",
|
||||||
}
|
}
|
||||||
assert.Equal(t, expectedParams, params)
|
assert.Equal(t, expectedParams, params)
|
||||||
|
assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "openai retrieve model handler",
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/v1/models/show-model",
|
||||||
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
assert.Equal(t, "application/json", contentType)
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var retrieveResp api.RetrieveModelResponse
|
||||||
|
err = json.Unmarshal(body, &retrieveResp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "show-model", retrieveResp.Id)
|
||||||
|
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -327,6 +385,43 @@ func TestCase(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShow(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
envconfig.LoadConfig()
|
||||||
|
|
||||||
|
var s Server
|
||||||
|
|
||||||
|
createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "show-model",
|
||||||
|
Modelfile: fmt.Sprintf(
|
||||||
|
"FROM %s\nFROM %s",
|
||||||
|
createBinFile(t, llm.KV{"general.architecture": "test"}, nil),
|
||||||
|
createBinFile(t, llm.KV{"general.architecture": "clip"}, nil),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
w := createRequest(t, s.ShowModelHandler, api.ShowRequest{
|
||||||
|
Name: "show-model",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp api.ShowResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.ModelInfo["general.architecture"] != "test" {
|
||||||
|
t.Fatal("Expected model architecture to be 'test', but got", resp.ModelInfo["general.architecture"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.ProjectorInfo["general.architecture"] != "clip" {
|
||||||
|
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalize(t *testing.T) {
|
func TestNormalize(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
input []float32
|
input []float32
|
||||||
|
110
server/sched.go
110
server/sched.go
@ -23,6 +23,7 @@ type LlmRequest struct {
|
|||||||
ctx context.Context //nolint:containedctx
|
ctx context.Context //nolint:containedctx
|
||||||
model *Model
|
model *Model
|
||||||
opts api.Options
|
opts api.Options
|
||||||
|
origNumCtx int // Track the initial ctx request
|
||||||
sessionDuration time.Duration
|
sessionDuration time.Duration
|
||||||
successCh chan *runnerRef
|
successCh chan *runnerRef
|
||||||
errCh chan error
|
errCh chan error
|
||||||
@ -38,13 +39,23 @@ type Scheduler struct {
|
|||||||
loaded map[string]*runnerRef
|
loaded map[string]*runnerRef
|
||||||
loadedMu sync.Mutex
|
loadedMu sync.Mutex
|
||||||
|
|
||||||
loadFn func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
|
loadFn func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int)
|
||||||
newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
|
newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
|
||||||
getGpuFn func() gpu.GpuInfoList
|
getGpuFn func() gpu.GpuInfoList
|
||||||
getCpuFn func() gpu.GpuInfoList
|
getCpuFn func() gpu.GpuInfoList
|
||||||
reschedDelay time.Duration
|
reschedDelay time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default automatic value for number of models we allow per GPU
|
||||||
|
// Model will still need to fit in VRAM, but loading many small models
|
||||||
|
// on a large GPU can cause stalling
|
||||||
|
var defaultModelsPerGPU = 3
|
||||||
|
|
||||||
|
// Default automatic value for parallel setting
|
||||||
|
// Model will still need to fit in VRAM. If this setting wont fit
|
||||||
|
// we'll back off down to 1 to try to get it to fit
|
||||||
|
var defaultParallel = 4
|
||||||
|
|
||||||
var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
|
var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
|
||||||
|
|
||||||
func InitScheduler(ctx context.Context) *Scheduler {
|
func InitScheduler(ctx context.Context) *Scheduler {
|
||||||
@ -65,13 +76,10 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
|||||||
|
|
||||||
// context must be canceled to decrement ref count and release the runner
|
// context must be canceled to decrement ref count and release the runner
|
||||||
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
|
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
|
||||||
// allocate a large enough kv cache for all parallel requests
|
|
||||||
if opts.NumCtx < 4 {
|
if opts.NumCtx < 4 {
|
||||||
opts.NumCtx = 4
|
opts.NumCtx = 4
|
||||||
}
|
}
|
||||||
|
|
||||||
opts.NumCtx *= envconfig.NumParallel
|
|
||||||
|
|
||||||
req := &LlmRequest{
|
req := &LlmRequest{
|
||||||
ctx: c,
|
ctx: c,
|
||||||
model: model,
|
model: model,
|
||||||
@ -110,11 +118,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
case pending := <-s.pendingReqCh:
|
case pending := <-s.pendingReqCh:
|
||||||
// Block other requests until we get this pending request running
|
// Block other requests until we get this pending request running
|
||||||
pending.schedAttempts++
|
pending.schedAttempts++
|
||||||
|
if pending.origNumCtx == 0 {
|
||||||
|
pending.origNumCtx = pending.opts.NumCtx
|
||||||
|
}
|
||||||
|
|
||||||
if pending.ctx.Err() != nil {
|
if pending.ctx.Err() != nil {
|
||||||
slog.Debug("pending request cancelled or timed out, skipping scheduling")
|
slog.Debug("pending request cancelled or timed out, skipping scheduling")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
numParallel := envconfig.NumParallel
|
||||||
|
// TODO (jmorganca): multimodal models don't support parallel yet
|
||||||
|
// see https://github.com/ollama/ollama/issues/4165
|
||||||
|
if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
|
||||||
|
numParallel = 1
|
||||||
|
slog.Warn("multimodal models don't support parallel requests yet")
|
||||||
|
}
|
||||||
|
// Keep NumCtx and numParallel in sync
|
||||||
|
if numParallel > 1 {
|
||||||
|
pending.opts.NumCtx = pending.origNumCtx * numParallel
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var runnerToExpire *runnerRef
|
var runnerToExpire *runnerRef
|
||||||
@ -143,8 +165,28 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
gpus = s.getGpuFn()
|
gpus = s.getGpuFn()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if envconfig.MaxRunners <= 0 {
|
||||||
|
// No user specified MaxRunners, so figure out what automatic setting to use
|
||||||
|
// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
|
||||||
|
// if any GPU has unreliable free memory reporting, 1x the number of GPUs
|
||||||
|
allReliable := true
|
||||||
|
for _, gpu := range gpus {
|
||||||
|
if gpu.UnreliableFreeMemory {
|
||||||
|
allReliable = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if allReliable {
|
||||||
|
envconfig.MaxRunners = defaultModelsPerGPU * len(gpus)
|
||||||
|
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
|
||||||
|
} else {
|
||||||
|
slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency")
|
||||||
|
envconfig.MaxRunners = len(gpus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Load model for fitting
|
// Load model for fitting
|
||||||
ggml, err := llm.LoadModel(pending.model.ModelPath)
|
ggml, err := llm.LoadModel(pending.model.ModelPath, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pending.errCh <- err
|
pending.errCh <- err
|
||||||
break
|
break
|
||||||
@ -152,26 +194,32 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
|
|
||||||
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
|
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
|
||||||
if len(gpus) == 1 && gpus[0].Library == "cpu" {
|
if len(gpus) == 1 && gpus[0].Library == "cpu" {
|
||||||
|
// simplifying assumption of defaultParallel when in CPU mode
|
||||||
|
if numParallel <= 0 {
|
||||||
|
numParallel = defaultParallel
|
||||||
|
pending.opts.NumCtx = pending.origNumCtx * numParallel
|
||||||
|
}
|
||||||
|
|
||||||
if loadedCount == 0 {
|
if loadedCount == 0 {
|
||||||
slog.Debug("cpu mode with first model, loading")
|
slog.Debug("cpu mode with first model, loading")
|
||||||
s.loadFn(pending, ggml, gpus)
|
s.loadFn(pending, ggml, gpus, numParallel)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus)
|
runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus)
|
||||||
if runnerToExpire == nil {
|
if runnerToExpire == nil {
|
||||||
slog.Debug("cpu mode with available system memory or first model, loading")
|
slog.Debug("cpu mode with available system memory or first model, loading")
|
||||||
s.loadFn(pending, ggml, gpus)
|
s.loadFn(pending, ggml, gpus, numParallel)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
// else we need to expire a runner
|
// else we need to expire a runner
|
||||||
} else if loadedCount == 0 {
|
} else if loadedCount == 0 {
|
||||||
// No models loaded. Load the model but prefer the best fit.
|
// No models loaded. Load the model but prefer the best fit.
|
||||||
slog.Debug("loading first model", "model", pending.model.ModelPath)
|
slog.Debug("loading first model", "model", pending.model.ModelPath)
|
||||||
g := pickBestFitGPUs(pending, ggml, gpus)
|
g := pickBestFitGPUs(pending, ggml, gpus, &numParallel)
|
||||||
if g != nil {
|
if g != nil {
|
||||||
gpus = g
|
gpus = g
|
||||||
}
|
}
|
||||||
s.loadFn(pending, ggml, gpus)
|
s.loadFn(pending, ggml, gpus, numParallel)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,10 +234,10 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
|
|
||||||
// Update free memory from currently loaded models
|
// Update free memory from currently loaded models
|
||||||
s.updateFreeSpace(availGpus)
|
s.updateFreeSpace(availGpus)
|
||||||
fitGpus := pickBestFitGPUs(pending, ggml, availGpus)
|
fitGpus := pickBestFitGPUs(pending, ggml, availGpus, &numParallel)
|
||||||
if fitGpus != nil {
|
if fitGpus != nil {
|
||||||
slog.Debug("new model fits with existing models, loading")
|
slog.Debug("new model fits with existing models, loading")
|
||||||
s.loadFn(pending, ggml, fitGpus)
|
s.loadFn(pending, ggml, fitGpus, numParallel)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -350,8 +398,11 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) {
|
func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
|
if numParallel < 1 {
|
||||||
|
numParallel = 1
|
||||||
|
}
|
||||||
|
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// some older models are not compatible with newer versions of llama.cpp
|
// some older models are not compatible with newer versions of llama.cpp
|
||||||
// show a generalized compatibility error until there is a better way to
|
// show a generalized compatibility error until there is a better way to
|
||||||
@ -375,6 +426,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
|
|||||||
loading: true,
|
loading: true,
|
||||||
refCount: 1,
|
refCount: 1,
|
||||||
}
|
}
|
||||||
|
runner.numParallel = numParallel
|
||||||
runner.refMu.Lock()
|
runner.refMu.Lock()
|
||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@ -485,6 +537,7 @@ type runnerRef struct {
|
|||||||
|
|
||||||
model *Model
|
model *Model
|
||||||
modelPath string
|
modelPath string
|
||||||
|
numParallel int
|
||||||
*api.Options
|
*api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -525,6 +578,9 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
|||||||
optsNew.NumGPU = -1
|
optsNew.NumGPU = -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Normalize the NumCtx for parallelism
|
||||||
|
optsExisting.NumCtx = optsExisting.NumCtx / runner.numParallel
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||||
@ -611,36 +667,56 @@ func (a ByDuration) Less(i, j int) bool {
|
|||||||
|
|
||||||
// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
|
// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
|
||||||
// If the model can not be fit fully within the available GPU(s) nil is returned
|
// If the model can not be fit fully within the available GPU(s) nil is returned
|
||||||
func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.GpuInfoList {
|
// If numParallel is <= 0, this will attempt try to optimize parallism based on available VRAM, and adjust
|
||||||
|
// opts.NumCtx accordingly
|
||||||
|
func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel *int) gpu.GpuInfoList {
|
||||||
var estimatedVRAM uint64
|
var estimatedVRAM uint64
|
||||||
|
|
||||||
|
var numParallelToTry []int
|
||||||
|
if *numParallel <= 0 {
|
||||||
|
// If no specific parallel setting was provided, try larger then smaller, always end with 1
|
||||||
|
numParallelToTry = append(numParallelToTry, defaultParallel, 1)
|
||||||
|
} else {
|
||||||
|
numParallelToTry = []int{*numParallel}
|
||||||
|
}
|
||||||
|
|
||||||
for _, gl := range gpus.ByLibrary() {
|
for _, gl := range gpus.ByLibrary() {
|
||||||
var ok bool
|
var ok bool
|
||||||
sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
|
sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
|
||||||
|
|
||||||
// TODO - potentially sort by performance capability, existing models loaded, etc.
|
// TODO - potentially sort by performance capability, existing models loaded, etc.
|
||||||
|
// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
|
||||||
// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
|
// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
|
||||||
sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
|
sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
|
||||||
|
|
||||||
// First attempt to fit the model into a single GPU
|
// First attempt to fit the model into a single GPU
|
||||||
|
for _, p := range numParallelToTry {
|
||||||
|
req.opts.NumCtx = req.origNumCtx * p
|
||||||
if !envconfig.SchedSpread {
|
if !envconfig.SchedSpread {
|
||||||
for _, g := range sgl {
|
for _, g := range sgl {
|
||||||
if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
||||||
slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
|
slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
|
||||||
|
*numParallel = p
|
||||||
return []gpu.GpuInfo{g}
|
return []gpu.GpuInfo{g}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO future refinements
|
// TODO future refinements
|
||||||
// - if multiple Libraries, see if any single GPU in any Library will fit
|
// - if multiple Libraries, see if any single GPU in any Library will fit
|
||||||
// - try subsets of GPUs instead of just falling back to 1 or all in a family
|
// - try subsets of GPUs instead of just falling back to 1 or all in a family
|
||||||
|
|
||||||
// Now try all the GPUs
|
// Now try all the GPUs
|
||||||
|
for _, p := range numParallelToTry {
|
||||||
|
req.opts.NumCtx = req.origNumCtx * p
|
||||||
if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
||||||
slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
|
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
|
||||||
|
*numParallel = p
|
||||||
return sgl
|
return sgl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,11 +47,11 @@ func TestLoad(t *testing.T) {
|
|||||||
sessionDuration: 2,
|
sessionDuration: 2,
|
||||||
}
|
}
|
||||||
// Fail to load model first
|
// Fail to load model first
|
||||||
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
|
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
return nil, fmt.Errorf("something failed to load model blah")
|
return nil, fmt.Errorf("something failed to load model blah")
|
||||||
}
|
}
|
||||||
gpus := gpu.GpuInfoList{}
|
gpus := gpu.GpuInfoList{}
|
||||||
s.load(req, ggml, gpus)
|
s.load(req, ggml, gpus, 0)
|
||||||
require.Empty(t, req.successCh)
|
require.Empty(t, req.successCh)
|
||||||
require.Len(t, req.errCh, 1)
|
require.Len(t, req.errCh, 1)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@ -61,10 +61,10 @@ func TestLoad(t *testing.T) {
|
|||||||
require.Contains(t, err.Error(), "this model may be incompatible")
|
require.Contains(t, err.Error(), "this model may be incompatible")
|
||||||
|
|
||||||
server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
|
server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
|
||||||
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
|
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
s.load(req, ggml, gpus)
|
s.load(req, ggml, gpus, 0)
|
||||||
select {
|
select {
|
||||||
case err := <-req.errCh:
|
case err := <-req.errCh:
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -78,12 +78,12 @@ func TestLoad(t *testing.T) {
|
|||||||
|
|
||||||
req.model.ModelPath = "dummy_model_path"
|
req.model.ModelPath = "dummy_model_path"
|
||||||
server.waitResp = fmt.Errorf("wait failure")
|
server.waitResp = fmt.Errorf("wait failure")
|
||||||
s.load(req, ggml, gpus)
|
s.load(req, ggml, gpus, 0)
|
||||||
select {
|
select {
|
||||||
case err := <-req.errCh:
|
case err := <-req.errCh:
|
||||||
require.Contains(t, err.Error(), "wait failure")
|
require.Contains(t, err.Error(), "wait failure")
|
||||||
case resp := <-req.successCh:
|
case resp := <-req.successCh:
|
||||||
t.Errorf("unexpected success %v", resp)
|
t.Fatalf("unexpected success %v", resp)
|
||||||
}
|
}
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
runner := s.loaded["dummy_model_path"]
|
runner := s.loaded["dummy_model_path"]
|
||||||
@ -102,7 +102,7 @@ type bundle struct {
|
|||||||
ggml *llm.GGML
|
ggml *llm.GGML
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
|
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
return scenario.srv, nil
|
return scenario.srv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,14 +128,14 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
|||||||
"tokenizer.ggml.scores": []float32{0},
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
"tokenizer.ggml.token_type": []int32{0},
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
}, []llm.Tensor{
|
}, []llm.Tensor{
|
||||||
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||||
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
fname := f.Name()
|
fname := f.Name()
|
||||||
model := &Model{Name: modelName, ModelPath: fname}
|
model := &Model{Name: modelName, ModelPath: fname}
|
||||||
scenario.ggml, err = llm.LoadModel(model.ModelPath)
|
scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
scenario.req = &LlmRequest{
|
scenario.req = &LlmRequest{
|
||||||
@ -200,7 +200,7 @@ func TestRequests(t *testing.T) {
|
|||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario1a.req.errCh)
|
require.Empty(t, scenario1a.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same runner as first request due to not needing a reload
|
// Same runner as first request due to not needing a reload
|
||||||
@ -213,7 +213,7 @@ func TestRequests(t *testing.T) {
|
|||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario1b.req.errCh)
|
require.Empty(t, scenario1b.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger a reload
|
// Trigger a reload
|
||||||
@ -231,7 +231,7 @@ func TestRequests(t *testing.T) {
|
|||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario2a.req.errCh)
|
require.Empty(t, scenario2a.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
envconfig.MaxRunners = 1
|
envconfig.MaxRunners = 1
|
||||||
@ -247,7 +247,7 @@ func TestRequests(t *testing.T) {
|
|||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3a.req.errCh)
|
require.Empty(t, scenario3a.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 1)
|
require.Len(t, s.loaded, 1)
|
||||||
@ -263,7 +263,7 @@ func TestRequests(t *testing.T) {
|
|||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3b.req.errCh)
|
require.Empty(t, scenario3b.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 2)
|
require.Len(t, s.loaded, 2)
|
||||||
@ -279,7 +279,7 @@ func TestRequests(t *testing.T) {
|
|||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3c.req.errCh)
|
require.Empty(t, scenario3c.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 3)
|
require.Len(t, s.loaded, 3)
|
||||||
@ -306,7 +306,7 @@ func TestRequests(t *testing.T) {
|
|||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3d.req.errCh)
|
require.Empty(t, scenario3d.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 2)
|
require.Len(t, s.loaded, 2)
|
||||||
@ -349,7 +349,7 @@ func TestGetRunner(t *testing.T) {
|
|||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, errCh1a)
|
require.Empty(t, errCh1a)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
scenario1a.ctxDone()
|
scenario1a.ctxDone()
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@ -400,7 +400,7 @@ func TestPrematureExpired(t *testing.T) {
|
|||||||
slog.Info("sending premature expired event now")
|
slog.Info("sending premature expired event now")
|
||||||
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
time.Sleep(scenario1a.req.sessionDuration)
|
time.Sleep(scenario1a.req.sessionDuration)
|
||||||
scenario1a.ctxDone()
|
scenario1a.ctxDone()
|
||||||
@ -427,7 +427,7 @@ func TestUseLoadedRunner(t *testing.T) {
|
|||||||
}
|
}
|
||||||
finished := make(chan *LlmRequest)
|
finished := make(chan *LlmRequest)
|
||||||
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
||||||
r1 := &runnerRef{llama: llm1, sessionDuration: 1}
|
r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1}
|
||||||
req.useLoadedRunner(r1, finished)
|
req.useLoadedRunner(r1, finished)
|
||||||
require.Equal(t, uint(1), r1.refCount)
|
require.Equal(t, uint(1), r1.refCount)
|
||||||
require.Equal(t, time.Duration(2), r1.sessionDuration)
|
require.Equal(t, time.Duration(2), r1.sessionDuration)
|
||||||
@ -435,7 +435,7 @@ func TestUseLoadedRunner(t *testing.T) {
|
|||||||
case success := <-req.successCh:
|
case success := <-req.successCh:
|
||||||
require.Equal(t, r1, success)
|
require.Equal(t, r1, success)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Errorf("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
done()
|
done()
|
||||||
fin := <-finished
|
fin := <-finished
|
||||||
@ -461,8 +461,8 @@ func TestUpdateFreeSpace(t *testing.T) {
|
|||||||
gpus[1].FreeMemory = 1900
|
gpus[1].FreeMemory = 1900
|
||||||
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}}
|
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}}
|
||||||
llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}}
|
llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}}
|
||||||
r1 := &runnerRef{llama: llm1, gpus: gpus}
|
r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1}
|
||||||
r2 := &runnerRef{llama: llm2, gpus: gpus}
|
r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1}
|
||||||
|
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@ -513,8 +513,8 @@ func TestFindRunnerToUnload(t *testing.T) {
|
|||||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
|
r1 := &runnerRef{refCount: 1, sessionDuration: 1, numParallel: 1}
|
||||||
r2 := &runnerRef{sessionDuration: 2}
|
r2 := &runnerRef{sessionDuration: 2, numParallel: 1}
|
||||||
|
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
@ -536,9 +536,13 @@ func TestNeedsReload(t *testing.T) {
|
|||||||
llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
||||||
do := api.DefaultOptions()
|
do := api.DefaultOptions()
|
||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}},
|
model: &Model{
|
||||||
|
AdapterPaths: []string{"adapter1"},
|
||||||
|
ProjectorPaths: []string{"projector1"},
|
||||||
|
},
|
||||||
Options: &do,
|
Options: &do,
|
||||||
llama: llm,
|
llama: llm,
|
||||||
|
numParallel: 1,
|
||||||
}
|
}
|
||||||
req := &LlmRequest{
|
req := &LlmRequest{
|
||||||
model: &Model{
|
model: &Model{
|
||||||
@ -581,8 +585,8 @@ func TestUnloadAllRunners(t *testing.T) {
|
|||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.unloadAllRunners()
|
s.unloadAllRunners()
|
||||||
|
|
||||||
r1 := &runnerRef{llama: llm1}
|
r1 := &runnerRef{llama: llm1, numParallel: 1}
|
||||||
r2 := &runnerRef{llama: llm2}
|
r2 := &runnerRef{llama: llm2, numParallel: 1}
|
||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
s.loaded["a"] = r1
|
s.loaded["a"] = r1
|
||||||
@ -596,14 +600,32 @@ func TestUnloadAllRunners(t *testing.T) {
|
|||||||
|
|
||||||
func TestUnload(t *testing.T) {
|
func TestUnload(t *testing.T) {
|
||||||
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
||||||
r1 := &runnerRef{llama: llm1}
|
r1 := &runnerRef{llama: llm1, numParallel: 1}
|
||||||
r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}}
|
r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1}
|
||||||
r1.unload()
|
r1.unload()
|
||||||
require.True(t, llm1.closeCalled)
|
require.True(t, llm1.closeCalled)
|
||||||
r2.unload()
|
r2.unload()
|
||||||
require.Nil(t, r2.model)
|
require.Nil(t, r2.model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAlreadyCanceled(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
dctx, done2 := context.WithCancel(ctx)
|
||||||
|
done2()
|
||||||
|
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
|
||||||
|
scenario1a.req.sessionDuration = 0
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
slog.Info("scenario1a")
|
||||||
|
s.pendingReqCh <- scenario1a.req
|
||||||
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
|
s.Run(ctx)
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
require.Empty(t, s.pendingReqCh)
|
||||||
|
require.Empty(t, scenario1a.req.errCh)
|
||||||
|
require.Empty(t, scenario1a.req.successCh)
|
||||||
|
}
|
||||||
|
|
||||||
type mockLlm struct {
|
type mockLlm struct {
|
||||||
pingResp error
|
pingResp error
|
||||||
waitResp error
|
waitResp error
|
||||||
|
158
template/template.go
Normal file
158
template/template.go
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
package template
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"embed"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"text/template"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
|
"github.com/agnivade/levenshtein"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed index.json
|
||||||
|
var indexBytes []byte
|
||||||
|
|
||||||
|
//go:embed *.gotmpl
|
||||||
|
var templatesFS embed.FS
|
||||||
|
|
||||||
|
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
|
||||||
|
var templates []*named
|
||||||
|
if err := json.Unmarshal(indexBytes, &templates); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range templates {
|
||||||
|
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalize line endings
|
||||||
|
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return templates, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
type named struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Template string `json:"template"`
|
||||||
|
Bytes []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t named) Reader() io.Reader {
|
||||||
|
return bytes.NewReader(t.Bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Named(s string) (*named, error) {
|
||||||
|
templates, err := templatesOnce()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var template *named
|
||||||
|
score := math.MaxInt
|
||||||
|
for _, t := range templates {
|
||||||
|
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
|
||||||
|
score = s
|
||||||
|
template = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if score < 100 {
|
||||||
|
return template, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("no matching template found")
|
||||||
|
}
|
||||||
|
|
||||||
|
type Template struct {
|
||||||
|
*template.Template
|
||||||
|
raw string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) String() string {
|
||||||
|
return t.raw
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
|
||||||
|
|
||||||
|
func Parse(s string) (*Template, error) {
|
||||||
|
t, err := template.New("").Option("missingkey=zero").Parse(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Template{Template: t, raw: s}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) Vars() []string {
|
||||||
|
var vars []string
|
||||||
|
for _, n := range t.Tree.Root.Nodes {
|
||||||
|
vars = append(vars, parseNode(n)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
set := make(map[string]struct{})
|
||||||
|
for _, n := range vars {
|
||||||
|
set[strings.ToLower(n)] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
vars = maps.Keys(set)
|
||||||
|
slices.Sort(vars)
|
||||||
|
return vars
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseNode(n parse.Node) []string {
|
||||||
|
switch n := n.(type) {
|
||||||
|
case *parse.ActionNode:
|
||||||
|
return parseNode(n.Pipe)
|
||||||
|
case *parse.IfNode:
|
||||||
|
names := parseNode(n.Pipe)
|
||||||
|
names = append(names, parseNode(n.List)...)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
names = append(names, parseNode(n.ElseList)...)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
case *parse.RangeNode:
|
||||||
|
names := parseNode(n.Pipe)
|
||||||
|
names = append(names, parseNode(n.List)...)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
names = append(names, parseNode(n.ElseList)...)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
case *parse.WithNode:
|
||||||
|
names := parseNode(n.Pipe)
|
||||||
|
names = append(names, parseNode(n.List)...)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
names = append(names, parseNode(n.ElseList)...)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
case *parse.PipeNode:
|
||||||
|
var names []string
|
||||||
|
for _, c := range n.Cmds {
|
||||||
|
for _, a := range c.Args {
|
||||||
|
names = append(names, parseNode(a)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
case *parse.ListNode:
|
||||||
|
var names []string
|
||||||
|
for _, n := range n.Nodes {
|
||||||
|
names = append(names, parseNode(n)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return names
|
||||||
|
case *parse.FieldNode:
|
||||||
|
return n.Ident
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
89
template/template_test.go
Normal file
89
template/template_test.go
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
package template
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNamed(t *testing.T) {
|
||||||
|
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
var ss map[string]string
|
||||||
|
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range ss {
|
||||||
|
t.Run(k, func(t *testing.T) {
|
||||||
|
kv := llm.KV{"tokenizer.chat_template": v}
|
||||||
|
s := kv.ChatTemplate()
|
||||||
|
r, err := Named(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Name != k {
|
||||||
|
t.Errorf("expected %q, got %q", k, r.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := io.Copy(&b, r.Reader()); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl, err := template.New(s).Parse(b.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tmpl.Tree.Root.String() == "" {
|
||||||
|
t.Errorf("empty %s template", k)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
template string
|
||||||
|
vars []string
|
||||||
|
}{
|
||||||
|
{"{{ .Prompt }}", []string{"prompt"}},
|
||||||
|
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
|
||||||
|
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
|
||||||
|
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
|
||||||
|
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
|
||||||
|
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
|
||||||
|
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
tmpl, err := Parse(tt.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := tmpl.Vars()
|
||||||
|
if !slices.Equal(tt.vars, vars) {
|
||||||
|
t.Errorf("expected %v, got %v", tt.vars, vars)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,70 +0,0 @@
|
|||||||
package templates
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"embed"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/agnivade/levenshtein"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:embed index.json
|
|
||||||
var indexBytes []byte
|
|
||||||
|
|
||||||
//go:embed *.gotmpl
|
|
||||||
var templatesFS embed.FS
|
|
||||||
|
|
||||||
var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
|
|
||||||
var templates []*Template
|
|
||||||
if err := json.Unmarshal(indexBytes, &templates); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range templates {
|
|
||||||
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// normalize line endings
|
|
||||||
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
return templates, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
type Template struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Template string `json:"template"`
|
|
||||||
Bytes []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Template) Reader() io.Reader {
|
|
||||||
return bytes.NewReader(t.Bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NamedTemplate(s string) (*Template, error) {
|
|
||||||
templates, err := templatesOnce()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var template *Template
|
|
||||||
score := math.MaxInt
|
|
||||||
for _, t := range templates {
|
|
||||||
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
|
|
||||||
score = s
|
|
||||||
template = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if score < 100 {
|
|
||||||
return template, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("no matching template found")
|
|
||||||
}
|
|
@ -1,59 +0,0 @@
|
|||||||
package templates
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/llm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestKVChatTemplate(t *testing.T) {
|
|
||||||
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(f)
|
|
||||||
for scanner.Scan() {
|
|
||||||
var ss map[string]string
|
|
||||||
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range ss {
|
|
||||||
t.Run(k, func(t *testing.T) {
|
|
||||||
kv := llm.KV{"tokenizer.chat_template": v}
|
|
||||||
s := kv.ChatTemplate()
|
|
||||||
r, err := NamedTemplate(s)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Name != k {
|
|
||||||
t.Errorf("expected %q, got %q", k, r.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if _, err := io.Copy(&b, r.Reader()); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpl, err := template.New(s).Parse(b.String())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tmpl.Tree.Root.String() == "" {
|
|
||||||
t.Errorf("empty %s template", k)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -4,7 +4,6 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@ -371,57 +370,3 @@ func cutPromised(s, sep string) (before, after string, ok bool) {
|
|||||||
}
|
}
|
||||||
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
|
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
|
||||||
}
|
}
|
||||||
|
|
||||||
type DigestType byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
DigestTypeInvalid DigestType = iota
|
|
||||||
DigestTypeSHA256
|
|
||||||
)
|
|
||||||
|
|
||||||
func (t DigestType) String() string {
|
|
||||||
switch t {
|
|
||||||
case DigestTypeSHA256:
|
|
||||||
return "sha256"
|
|
||||||
default:
|
|
||||||
return "invalid"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Digest struct {
|
|
||||||
Type DigestType
|
|
||||||
Sum [32]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseDigest(s string) (Digest, error) {
|
|
||||||
i := strings.IndexAny(s, "-:")
|
|
||||||
if i < 0 {
|
|
||||||
return Digest{}, fmt.Errorf("invalid digest %q", s)
|
|
||||||
}
|
|
||||||
typ, encSum := s[:i], s[i+1:]
|
|
||||||
if typ != "sha256" {
|
|
||||||
return Digest{}, fmt.Errorf("unsupported digest type %q", typ)
|
|
||||||
}
|
|
||||||
d := Digest{
|
|
||||||
Type: DigestTypeSHA256,
|
|
||||||
}
|
|
||||||
n, err := hex.Decode(d.Sum[:], []byte(encSum))
|
|
||||||
if err != nil {
|
|
||||||
return Digest{}, err
|
|
||||||
}
|
|
||||||
if n != 32 {
|
|
||||||
return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n)
|
|
||||||
}
|
|
||||||
return d, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d Digest) String() string {
|
|
||||||
if d.Type == DigestTypeInvalid {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("sha256-%x", d.Sum)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d Digest) IsValid() bool {
|
|
||||||
return d.Type != DigestTypeInvalid
|
|
||||||
}
|
|
||||||
|
@ -284,40 +284,6 @@ func TestFilepathAllocs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000"
|
|
||||||
validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseDigest(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
in string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"", ""}, // empty
|
|
||||||
{"sha123-12", ""}, // invalid type
|
|
||||||
{"sha256-", ""}, // invalid sum
|
|
||||||
{"sha256-123", ""}, // invalid odd length sum
|
|
||||||
|
|
||||||
{validSha256, validSha256},
|
|
||||||
{validSha256Old, validSha256},
|
|
||||||
}
|
|
||||||
for _, tt := range cases {
|
|
||||||
t.Run(tt.in, func(t *testing.T) {
|
|
||||||
got, err := ParseDigest(tt.in)
|
|
||||||
if err != nil {
|
|
||||||
if tt.want != "" {
|
|
||||||
t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got.String() != tt.want {
|
|
||||||
t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseNameFromFilepath(t *testing.T) {
|
func TestParseNameFromFilepath(t *testing.T) {
|
||||||
cases := map[string]Name{
|
cases := map[string]Name{
|
||||||
filepath.Join("host", "namespace", "model", "tag"): {Host: "host", Namespace: "namespace", Model: "model", Tag: "tag"},
|
filepath.Join("host", "namespace", "model", "tag"): {Host: "host", Namespace: "namespace", Model: "model", Tag: "tag"},
|
||||||
|
34
util/bufioutil/buffer_seeker.go
Normal file
34
util/bufioutil/buffer_seeker.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package bufioutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BufferedSeeker struct {
|
||||||
|
rs io.ReadSeeker
|
||||||
|
br *bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBufferedSeeker(rs io.ReadSeeker, size int) *BufferedSeeker {
|
||||||
|
return &BufferedSeeker{
|
||||||
|
rs: rs,
|
||||||
|
br: bufio.NewReaderSize(rs, size),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BufferedSeeker) Read(p []byte) (int, error) {
|
||||||
|
return b.br.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BufferedSeeker) Seek(offset int64, whence int) (int64, error) {
|
||||||
|
if whence == io.SeekCurrent {
|
||||||
|
offset -= int64(b.br.Buffered())
|
||||||
|
}
|
||||||
|
n, err := b.rs.Seek(offset, whence)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
b.br.Reset(b.rs)
|
||||||
|
return n, nil
|
||||||
|
}
|
64
util/bufioutil/buffer_seeker_test.go
Normal file
64
util/bufioutil/buffer_seeker_test.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package bufioutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBufferedSeeker(t *testing.T) {
|
||||||
|
const alphabet = "abcdefghijklmnopqrstuvwxyz"
|
||||||
|
|
||||||
|
bs := NewBufferedSeeker(strings.NewReader(alphabet), 0) // minReadBufferSize = 16
|
||||||
|
|
||||||
|
checkRead := func(buf []byte, expected string) {
|
||||||
|
t.Helper()
|
||||||
|
_, err := bs.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(buf, []byte(expected)) {
|
||||||
|
t.Fatalf("expected %s, got %s", expected, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the first 5 bytes
|
||||||
|
buf := make([]byte, 5)
|
||||||
|
|
||||||
|
checkRead(buf, "abcde")
|
||||||
|
|
||||||
|
// Seek back to the beginning
|
||||||
|
_, err := bs.Seek(0, io.SeekStart)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// read 'a'
|
||||||
|
checkRead(buf[:1], "a")
|
||||||
|
|
||||||
|
if bs.br.Buffered() == 0 {
|
||||||
|
t.Fatalf("totally unexpected sanity check failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Seek past 'b'
|
||||||
|
_, err = bs.Seek(1, io.SeekCurrent)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
checkRead(buf, "cdefg")
|
||||||
|
|
||||||
|
// Seek back to the beginning
|
||||||
|
_, err = bs.Seek(0, io.SeekStart)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
checkRead(buf, "abcde")
|
||||||
|
|
||||||
|
// Seek to the end
|
||||||
|
_, err = bs.Seek(-5, io.SeekEnd)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
checkRead(buf, "vwxyz")
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user