From b2799f111b6b662690ec3f705f4ab95a7f5d7df4 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Sat, 15 Jun 2024 13:17:20 -0700 Subject: [PATCH 01/54] Move libraries out of users path We update the PATH on windows to get the CLI mapped, but this has an unintended side effect of causing other apps that may use our bundled DLLs to get terminated when we upgrade. --- app/ollama.iss | 7 ++++++- gpu/gpu.go | 10 ++++++++-- llm/generate/gen_windows.ps1 | 32 ++++++++++++++++++-------------- llm/payload.go | 4 ++-- llm/server.go | 4 ++-- scripts/build_windows.ps1 | 10 +++++----- 6 files changed, 41 insertions(+), 26 deletions(-) diff --git a/app/ollama.iss b/app/ollama.iss index 9dc61abbf..e6502abd3 100644 --- a/app/ollama.iss +++ b/app/ollama.iss @@ -88,10 +88,15 @@ DialogFontSize=12 [Files] Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit -Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion +#if DirExists("..\dist\windows-amd64\cuda") + Source: "..\dist\windows-amd64\cuda\*"; DestDir: "{app}\cuda\"; Flags: ignoreversion recursesubdirs +#endif +#if DirExists("..\dist\windows-amd64\oneapi") + Source: "..\dist\windows-amd64\oneapi\*"; DestDir: "{app}\oneapi\"; Flags: ignoreversion recursesubdirs +#endif #if DirExists("..\dist\windows-amd64\rocm") Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs #endif diff --git a/gpu/gpu.go b/gpu/gpu.go index ce0a10496..583bb79c6 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -231,7 +231,7 @@ func GetGPUInfo() GpuInfoList { // On windows we bundle the nvidia library one level above the runner dir depPath := "" if runtime.GOOS == "windows" && envconfig.RunnersDir != "" { - depPath = filepath.Dir(envconfig.RunnersDir) + depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "cuda") } // Load ALL libraries @@ -282,6 +282,12 @@ func GetGPUInfo() GpuInfoList { // Intel if envconfig.IntelGpu { oHandles = initOneAPIHandles() + // On windows we bundle the oneapi library one level above the runner dir + depPath = "" + if runtime.GOOS == "windows" && envconfig.RunnersDir != "" { + depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "oneapi") + } + for d := range oHandles.oneapi.num_drivers { if oHandles.oneapi == nil { // shouldn't happen @@ -306,7 +312,7 @@ func GetGPUInfo() GpuInfoList { gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - // TODO dependency path? 
+ gpuInfo.DependencyPath = depPath oneapiGPUs = append(oneapiGPUs, gpuInfo) } } diff --git a/llm/generate/gen_windows.ps1 b/llm/generate/gen_windows.ps1 index 0eb48ffac..a3c53e63c 100644 --- a/llm/generate/gen_windows.ps1 +++ b/llm/generate/gen_windows.ps1 @@ -299,10 +299,12 @@ function build_cuda() { sign install - write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\" - cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\" - cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\" - cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\" + rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\" + md "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\" -ea 0 > $null + write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\" + cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\" + cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\" + cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\cuda\" } else { write-host "Skipping CUDA generation step" } @@ -336,16 +338,18 @@ function build_oneapi() { sign install - cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libirngmd.dll" "${script:distDir}" - cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libmmd.dll" "${script:distDir}" - cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_level_zero.dll" "${script:distDir}" - cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_unified_runtime.dll" "${script:distDir}" - cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_win_proxy_loader.dll" "${script:distDir}" - cp "${env:ONEAPI_ROOT}\compiler\latest\bin\svml_dispmd.dll" "${script:distDir}" - cp "${env:ONEAPI_ROOT}\compiler\latest\bin\sycl7.dll" "${script:distDir}" - cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_core.2.dll" "${script:distDir}" - cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_sycl_blas.4.dll" "${script:distDir}" - cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_tbb_thread.2.dll" "${script:distDir}" + rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + md "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" -ea 0 > $null + cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libirngmd.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + cp "${env:ONEAPI_ROOT}\compiler\latest\bin\libmmd.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_level_zero.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_unified_runtime.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + cp "${env:ONEAPI_ROOT}\compiler\latest\bin\pi_win_proxy_loader.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + cp "${env:ONEAPI_ROOT}\compiler\latest\bin\svml_dispmd.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + cp "${env:ONEAPI_ROOT}\compiler\latest\bin\sycl7.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_core.2.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_sycl_blas.4.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" + cp "${env:ONEAPI_ROOT}\mkl\latest\bin\mkl_tbb_thread.2.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\oneapi\" } else { Write-Host "Skipping 
oneAPI generation step" } diff --git a/llm/payload.go b/llm/payload.go index 20dcee7b5..9296db336 100644 --- a/llm/payload.go +++ b/llm/payload.go @@ -58,7 +58,7 @@ func availableServers() map[string]string { } // glob payloadsDir for files that start with ollama_ - pattern := filepath.Join(payloadsDir, "*") + pattern := filepath.Join(payloadsDir, "*", "ollama_*") files, err := filepath.Glob(pattern) if err != nil { @@ -69,7 +69,7 @@ func availableServers() map[string]string { servers := make(map[string]string) for _, file := range files { slog.Debug("availableServers : found", "file", file) - servers[filepath.Base(file)] = file + servers[filepath.Base(filepath.Dir(file))] = filepath.Dir(file) } return servers diff --git a/llm/server.go b/llm/server.go index 117565ba5..a94633ee6 100644 --- a/llm/server.go +++ b/llm/server.go @@ -271,8 +271,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr if runtime.GOOS == "windows" { pathEnv = "PATH" } - // prepend the server directory to LD_LIBRARY_PATH/PATH - libraryPaths := []string{dir} + // prepend the server directory to LD_LIBRARY_PATH/PATH and the parent dir for common dependencies + libraryPaths := []string{dir, filepath.Dir(dir)} if libraryPath, ok := os.LookupEnv(pathEnv); ok { // Append our runner directory to the path diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 60de03073..b3991ce1f 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -103,19 +103,19 @@ function buildApp() { function gatherDependencies() { write-host "Gathering runtime dependencies" cd "${script:SRC_DIR}" - md "${script:DEPS_DIR}" -ea 0 > $null + md "${script:DEPS_DIR}\ollama_runners" -ea 0 > $null # TODO - this varies based on host build system and MSVC version - drive from dumpbin output # currently works for Win11 + MSVC 2019 + Cuda V11 - cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\" - cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\" - cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\" + cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\" + cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\" + cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\" cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\" if ("${env:KEY_CONTAINER}") { write-host "about to sign" - foreach ($file in (get-childitem "${script:DEPS_DIR}/cu*.dll") + @("${script:SRC_DIR}\dist\ollama_welcome.ps1")){ + foreach ($file in (get-childitem "${script:DEPS_DIR}\cuda\cu*.dll") + @("${script:SRC_DIR}\dist\ollama_welcome.ps1")){ write-host "signing $file" & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" ` /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} $file From e873841cbb38d9d8f1b058e1338d88eaffbf9afa Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 18 Jun 2024 12:42:37 -0700 Subject: [PATCH 02/54] deepseek v2 graph --- llm/ggml.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/llm/ggml.go b/llm/ggml.go index 35b89d16e..4d9ba97a8 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -367,6 +367,17 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui 4*batch*(vocab+2*embedding), fullOffload, ) + case "deepseek2": + keys := 
uint64(llm.KV()["deepseek2.attention.key_length"].(uint32)) + fullOffload = max( + 4*batch*(3*embedding+vocab), + 4*batch*(3*embedding+2+context*(1+headsKV)+2*keys*headsKV), + ) + + partialOffload = max( + 4*batch*(3*embedding+vocab)+embedding*vocab*105/128, + 4*batch*(2*embedding+1+2*keys*headsKV+context+context*headsKV)+4*keys*context*headsKV+embedding*keys*headsKV*9/16, + ) } return From 1a1c99e3346da21bf2062fa266cf39da954c66a8 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 18 Jun 2024 17:13:54 -0700 Subject: [PATCH 03/54] Bump latest fedora cuda repo to 39 --- scripts/install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/install.sh b/scripts/install.sh index 0f12d7e09..2a06c350a 100644 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -279,7 +279,7 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\ case $OS_NAME in centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;; rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;; - fedora) [ $OS_VERSION -lt '37' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '37';; + fedora) [ $OS_VERSION -lt '39' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '39';; amzn) install_cuda_driver_yum 'fedora' '37' ;; debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;; ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;; From 755b4e4fc291366595ed7bfb37c2a91ff5834df8 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Wed, 19 Jun 2024 08:59:58 +0800 Subject: [PATCH 04/54] Revert "gpu: add env var for detecting Intel oneapi gpus (#5076)" This reverts commit 163cd3e77c42aafd003b9cb884b3a51cdbaea106. --- envconfig/config.go | 7 ------ gpu/gpu.go | 54 ++++++++++++++++++++++----------------------- 2 files changed, 26 insertions(+), 35 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index e86f72e6a..bcf2e18ae 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -57,8 +57,6 @@ var ( SchedSpread bool // Set via OLLAMA_TMPDIR in the environment TmpDir string - // Set via OLLAMA_INTEL_GPU in the environment - IntelGpu bool // Set via CUDA_VISIBLE_DEVICES in the environment CudaVisibleDevices string @@ -103,7 +101,6 @@ func AsMap() map[string]EnvVar { ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices, "Set which AMD devices are visible"} ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal, "Set which AMD devices are visible"} ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion, "Override the gfx used for all detected AMD GPUs"} - ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGpu, "Enable experimental Intel GPU detection"} } return ret } @@ -279,10 +276,6 @@ func LoadConfig() { slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port) } - if set, err := strconv.ParseBool(clean("OLLAMA_INTEL_GPU")); err == nil { - IntelGpu = set - } - CudaVisibleDevices = clean("CUDA_VISIBLE_DEVICES") HipVisibleDevices = clean("HIP_VISIBLE_DEVICES") RocrVisibleDevices = clean("ROCR_VISIBLE_DEVICES") diff --git a/gpu/gpu.go b/gpu/gpu.go index ce0a10496..56a4dbfac 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -280,35 +280,33 @@ func GetGPUInfo() GpuInfoList { } // Intel - if envconfig.IntelGpu { - oHandles = initOneAPIHandles() - for d := range oHandles.oneapi.num_drivers { - if oHandles.oneapi == nil { - 
// shouldn't happen - slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) - continue - } - devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) - for i := range devCount { - gpuInfo := OneapiGPUInfo{ - GpuInfo: GpuInfo{ - Library: "oneapi", - }, - driverIndex: int(d), - gpuIndex: int(i), - } - // TODO - split bootstrapping from updating free memory - C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) - // TODO - convert this to MinimumMemory based on testing... - var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. - memInfo.free = C.uint64_t(totalFreeMem) - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - // TODO dependency path? - oneapiGPUs = append(oneapiGPUs, gpuInfo) + oHandles = initOneAPIHandles() + for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ { + if oHandles.oneapi == nil { + // shouldn't happen + slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) + continue + } + devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) + for i := range devCount { + gpuInfo := OneapiGPUInfo{ + GpuInfo: GpuInfo{ + Library: "oneapi", + }, + driverIndex: d, + gpuIndex: int(i), } + // TODO - split bootstrapping from updating free memory + C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) + // TODO - convert this to MinimumMemory based on testing... + var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. + memInfo.free = C.uint64_t(totalFreeMem) + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) + // TODO dependency path? + oneapiGPUs = append(oneapiGPUs, gpuInfo) } } From badf975e45005c45cf5d7794a18c88f4e069f89c Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Wed, 19 Jun 2024 09:00:51 +0800 Subject: [PATCH 05/54] get real func ptr. --- gpu/gpu_info_oneapi.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpu/gpu_info_oneapi.c b/gpu/gpu_info_oneapi.c index 004377456..3ff708ea2 100644 --- a/gpu/gpu_info_oneapi.c +++ b/gpu/gpu_info_oneapi.c @@ -50,7 +50,7 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) { LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s); *l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s); - if (!l[i].p) { + if (!*(l[i].p)) { resp->oh.handle = NULL; char *msg = LOAD_ERR(); LOG(resp->oh.verbose, "dlerr: %s\n", msg); From 380e06e5bea06ae8ded37f47c37bd5d604194d3e Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Tue, 18 Jun 2024 13:29:38 -0700 Subject: [PATCH 06/54] types/model: remove Digest The Digest type in its current form is awkward to work with and presents challenges with regard to how it serializes via String using the '-' prefix. We currently only use this in ollama.com, so we'll move our specific needs around digest parsing and validation there. 
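
For downstream code that still needs to parse digests after this removal, a minimal Go sketch of the behaviour the deleted ParseDigest provided — accepting both the `sha256-` and the legacy `sha256:` separator — could look like the following; the function name and the `main` driver are illustrative only, not part of the ollama codebase.

```go
package main

import (
	"encoding/hex"
	"fmt"
	"strings"
)

// parseSHA256Digest mirrors the removed ParseDigest: it accepts either the
// "sha256-" or the legacy "sha256:" separator and decodes a 32-byte sum.
// The name is illustrative, not part of the ollama codebase.
func parseSHA256Digest(s string) ([32]byte, error) {
	var sum [32]byte
	i := strings.IndexAny(s, "-:")
	if i < 0 || s[:i] != "sha256" {
		return sum, fmt.Errorf("invalid digest %q", s)
	}
	n, err := hex.Decode(sum[:], []byte(s[i+1:]))
	if err != nil {
		return sum, err
	}
	if n != len(sum) {
		return sum, fmt.Errorf("digest %q decoded to %d bytes; want %d", s, n, len(sum))
	}
	return sum, nil
}

func main() {
	sum, err := parseSHA256Digest("sha256:" + strings.Repeat("00", 32))
	fmt.Printf("%x %v\n", sum[:4], err)
}
```
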
--- types/model/name.go | 55 ---------------------------------------- types/model/name_test.go | 34 ------------------------- 2 files changed, 89 deletions(-) diff --git a/types/model/name.go b/types/model/name.go index d85fd0c6c..e645a844c 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -4,7 +4,6 @@ package model import ( "cmp" - "encoding/hex" "errors" "fmt" "log/slog" @@ -371,57 +370,3 @@ func cutPromised(s, sep string) (before, after string, ok bool) { } return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true } - -type DigestType byte - -const ( - DigestTypeInvalid DigestType = iota - DigestTypeSHA256 -) - -func (t DigestType) String() string { - switch t { - case DigestTypeSHA256: - return "sha256" - default: - return "invalid" - } -} - -type Digest struct { - Type DigestType - Sum [32]byte -} - -func ParseDigest(s string) (Digest, error) { - i := strings.IndexAny(s, "-:") - if i < 0 { - return Digest{}, fmt.Errorf("invalid digest %q", s) - } - typ, encSum := s[:i], s[i+1:] - if typ != "sha256" { - return Digest{}, fmt.Errorf("unsupported digest type %q", typ) - } - d := Digest{ - Type: DigestTypeSHA256, - } - n, err := hex.Decode(d.Sum[:], []byte(encSum)) - if err != nil { - return Digest{}, err - } - if n != 32 { - return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n) - } - return d, nil -} - -func (d Digest) String() string { - if d.Type == DigestTypeInvalid { - return "" - } - return fmt.Sprintf("sha256-%x", d.Sum) -} - -func (d Digest) IsValid() bool { - return d.Type != DigestTypeInvalid -} diff --git a/types/model/name_test.go b/types/model/name_test.go index 66ce4c339..008dd586c 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -284,40 +284,6 @@ func TestFilepathAllocs(t *testing.T) { } } -const ( - validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000" - validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000" -) - -func TestParseDigest(t *testing.T) { - cases := []struct { - in string - want string - }{ - {"", ""}, // empty - {"sha123-12", ""}, // invalid type - {"sha256-", ""}, // invalid sum - {"sha256-123", ""}, // invalid odd length sum - - {validSha256, validSha256}, - {validSha256Old, validSha256}, - } - for _, tt := range cases { - t.Run(tt.in, func(t *testing.T) { - got, err := ParseDigest(tt.in) - if err != nil { - if tt.want != "" { - t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want) - } - return - } - if got.String() != tt.want { - t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want) - } - }) - } -} - func TestParseNameFromFilepath(t *testing.T) { cases := map[string]Name{ filepath.Join("host", "namespace", "model", "tag"): {Host: "host", Namespace: "namespace", Model: "model", Tag: "tag"}, From 52ce350b7aecd4bce9c42fe4aac1d85e47a6d774 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 19 Jun 2024 08:39:07 -0700 Subject: [PATCH 07/54] Fix bad symbol load detection pointer deref's weren't correct on a few libraries, which explains some crashes on older systems or miswired symlinks for discovery libraries. 
--- gpu/gpu_info_cudart.c | 2 +- gpu/gpu_info_nvcuda.c | 2 +- gpu/gpu_info_nvml.c | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gpu/gpu_info_cudart.c b/gpu/gpu_info_cudart.c index 9db89529a..03f15a2c3 100644 --- a/gpu/gpu_info_cudart.c +++ b/gpu/gpu_info_cudart.c @@ -40,7 +40,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { for (i = 0; l[i].s != NULL; i++) { *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!l[i].p) { + if (!*(l[i].p)) { char *msg = LOAD_ERR(); LOG(resp->ch.verbose, "dlerr: %s\n", msg); UNLOAD_LIBRARY(resp->ch.handle); diff --git a/gpu/gpu_info_nvcuda.c b/gpu/gpu_info_nvcuda.c index 675ce5cc4..abe140844 100644 --- a/gpu/gpu_info_nvcuda.c +++ b/gpu/gpu_info_nvcuda.c @@ -43,7 +43,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { for (i = 0; l[i].s != NULL; i++) { *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!*l[i].p) { + if (!*(l[i].p)) { char *msg = LOAD_ERR(); LOG(resp->ch.verbose, "dlerr: %s\n", msg); UNLOAD_LIBRARY(resp->ch.handle); diff --git a/gpu/gpu_info_nvml.c b/gpu/gpu_info_nvml.c index ef0a97df2..11293e448 100644 --- a/gpu/gpu_info_nvml.c +++ b/gpu/gpu_info_nvml.c @@ -42,7 +42,7 @@ void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) { // LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s); *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!l[i].p) { + if (!*(l[i].p)) { resp->ch.handle = NULL; char *msg = LOAD_ERR(); LOG(resp->ch.verbose, "dlerr: %s\n", msg); From d34d88e41744bdcc36a425299af85ce762f3d30e Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 19 Jun 2024 08:57:41 -0700 Subject: [PATCH 08/54] Revert "Revert "gpu: add env var for detecting Intel oneapi gpus (#5076)"" This reverts commit 755b4e4fc291366595ed7bfb37c2a91ff5834df8. 
--- envconfig/config.go | 7 ++++++ gpu/gpu.go | 54 +++++++++++++++++++++++---------------------- 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index bcf2e18ae..e86f72e6a 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -57,6 +57,8 @@ var ( SchedSpread bool // Set via OLLAMA_TMPDIR in the environment TmpDir string + // Set via OLLAMA_INTEL_GPU in the environment + IntelGpu bool // Set via CUDA_VISIBLE_DEVICES in the environment CudaVisibleDevices string @@ -101,6 +103,7 @@ func AsMap() map[string]EnvVar { ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices, "Set which AMD devices are visible"} ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal, "Set which AMD devices are visible"} ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion, "Override the gfx used for all detected AMD GPUs"} + ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGpu, "Enable experimental Intel GPU detection"} } return ret } @@ -276,6 +279,10 @@ func LoadConfig() { slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port) } + if set, err := strconv.ParseBool(clean("OLLAMA_INTEL_GPU")); err == nil { + IntelGpu = set + } + CudaVisibleDevices = clean("CUDA_VISIBLE_DEVICES") HipVisibleDevices = clean("HIP_VISIBLE_DEVICES") RocrVisibleDevices = clean("ROCR_VISIBLE_DEVICES") diff --git a/gpu/gpu.go b/gpu/gpu.go index 56a4dbfac..ce0a10496 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -280,33 +280,35 @@ func GetGPUInfo() GpuInfoList { } // Intel - oHandles = initOneAPIHandles() - for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ { - if oHandles.oneapi == nil { - // shouldn't happen - slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) - continue - } - devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) - for i := range devCount { - gpuInfo := OneapiGPUInfo{ - GpuInfo: GpuInfo{ - Library: "oneapi", - }, - driverIndex: d, - gpuIndex: int(i), + if envconfig.IntelGpu { + oHandles = initOneAPIHandles() + for d := range oHandles.oneapi.num_drivers { + if oHandles.oneapi == nil { + // shouldn't happen + slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) + continue + } + devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d)) + for i := range devCount { + gpuInfo := OneapiGPUInfo{ + GpuInfo: GpuInfo{ + Library: "oneapi", + }, + driverIndex: int(d), + gpuIndex: int(i), + } + // TODO - split bootstrapping from updating free memory + C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) + // TODO - convert this to MinimumMemory based on testing... + var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. + memInfo.free = C.uint64_t(totalFreeMem) + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) + // TODO dependency path? + oneapiGPUs = append(oneapiGPUs, gpuInfo) } - // TODO - split bootstrapping from updating free memory - C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo) - // TODO - convert this to MinimumMemory based on testing... - var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend. 
- memInfo.free = C.uint64_t(totalFreeMem) - gpuInfo.TotalMemory = uint64(memInfo.total) - gpuInfo.FreeMemory = uint64(memInfo.free) - gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) - gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - // TODO dependency path? - oneapiGPUs = append(oneapiGPUs, gpuInfo) } } From 9d91e5e5875e2b2f8605ef15a7da9a616cb05171 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 19 Jun 2024 11:14:11 -0700 Subject: [PATCH 09/54] remove confusing log message --- llm/ext_server/server.cpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index 18b3fa18d..492126a4f 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -56,7 +56,6 @@ struct server_params { std::string hostname = "127.0.0.1"; std::vector api_keys; std::string public_path = "examples/server/public"; - std::string chat_template = ""; int32_t port = 8080; int32_t read_timeout = 600; int32_t write_timeout = 600; @@ -427,16 +426,6 @@ struct llama_server_context return true; } - void validate_model_chat_template(server_params & sparams) { - llama_chat_message chat[] = {{"user", "test"}}; - std::vector buf(1); - int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size()); - if (res < 0) { - LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); - sparams.chat_template = "chatml"; - } - } - void initialize() { // create slots all_slots_are_idle = true; @@ -2535,7 +2524,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g invalid_param = true; break; } - sparams.chat_template = argv[i]; } else if (arg == "--override-kv") { @@ -3008,11 +2996,6 @@ int main(int argc, char **argv) { } const auto model_meta = llama.model_meta(); - if (sparams.chat_template.empty()) { // custom chat template is not supplied - // check if the template comes with the model is supported by us - llama.validate_model_chat_template(sparams); - } - // Middleware for API key validation auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { // If API key is not set, skip validation From 9d8a4988e8eb743f985ebbb511a96e7496eac1c8 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Sat, 15 Jun 2024 16:30:37 -0700 Subject: [PATCH 10/54] Implement log rotation for tray app --- app/lifecycle/logging.go | 32 +++++++++++++++++++++++++ app/lifecycle/logging_test.go | 44 +++++++++++++++++++++++++++++++++++ app/lifecycle/paths.go | 11 +++++---- app/lifecycle/server.go | 2 +- docs/troubleshooting.md | 2 +- docs/windows.md | 4 ++-- 6 files changed, 86 insertions(+), 9 deletions(-) create mode 100644 app/lifecycle/logging_test.go diff --git a/app/lifecycle/logging.go b/app/lifecycle/logging.go index df2597a83..a8f1f7cdf 100644 --- a/app/lifecycle/logging.go +++ b/app/lifecycle/logging.go @@ -5,6 +5,8 @@ import ( "log/slog" "os" "path/filepath" + "strconv" + "strings" "github.com/ollama/ollama/envconfig" ) @@ -24,6 +26,7 @@ func InitLogging() { logFile = os.Stderr // TODO - write one-line to the app.log file saying we're running in console mode to help avoid confusion } else { + rotateLogs(AppLogFile) logFile, err = os.OpenFile(AppLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755) if err != nil { slog.Error(fmt.Sprintf("failed to create server log %v", err)) @@ -46,3 +49,32 @@ func InitLogging() { slog.Info("ollama app started") } + +func rotateLogs(logFile string) 
{ + if _, err := os.Stat(logFile); os.IsNotExist(err) { + return + } + index := strings.LastIndex(logFile, ".") + pre := logFile[:index] + post := "." + logFile[index+1:] + for i := LogRotationCount; i > 0; i-- { + older := pre + "-" + strconv.Itoa(i) + post + newer := pre + "-" + strconv.Itoa(i-1) + post + if i == 1 { + newer = pre + post + } + if _, err := os.Stat(newer); err == nil { + if _, err := os.Stat(older); err == nil { + err := os.Remove(older) + if err != nil { + slog.Warn("Failed to remove older log", "older", older, "error", err) + continue + } + } + err := os.Rename(newer, older) + if err != nil { + slog.Warn("Failed to rotate log", "older", older, "newer", newer, "error", err) + } + } + } +} diff --git a/app/lifecycle/logging_test.go b/app/lifecycle/logging_test.go new file mode 100644 index 000000000..a2157ca2c --- /dev/null +++ b/app/lifecycle/logging_test.go @@ -0,0 +1,44 @@ +package lifecycle + +import ( + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRotateLogs(t *testing.T) { + logDir := t.TempDir() + logFile := filepath.Join(logDir, "testlog.log") + + // No log exists + rotateLogs(logFile) + + require.NoError(t, os.WriteFile(logFile, []byte("1"), 0644)) + assert.FileExists(t, logFile) + // First rotation + rotateLogs(logFile) + assert.FileExists(t, filepath.Join(logDir, "testlog-1.log")) + assert.NoFileExists(t, filepath.Join(logDir, "testlog-2.log")) + assert.NoFileExists(t, logFile) + + // Should be a no-op without a new log + rotateLogs(logFile) + assert.FileExists(t, filepath.Join(logDir, "testlog-1.log")) + assert.NoFileExists(t, filepath.Join(logDir, "testlog-2.log")) + assert.NoFileExists(t, logFile) + + for i := 2; i <= LogRotationCount+1; i++ { + require.NoError(t, os.WriteFile(logFile, []byte(strconv.Itoa(i)), 0644)) + assert.FileExists(t, logFile) + rotateLogs(logFile) + assert.NoFileExists(t, logFile) + for j := 1; j < i; j++ { + assert.FileExists(t, filepath.Join(logDir, "testlog-"+strconv.Itoa(j)+".log")) + } + assert.NoFileExists(t, filepath.Join(logDir, "testlog-"+strconv.Itoa(i+1)+".log")) + } +} diff --git a/app/lifecycle/paths.go b/app/lifecycle/paths.go index fe07bce10..4d9f4c5a1 100644 --- a/app/lifecycle/paths.go +++ b/app/lifecycle/paths.go @@ -16,11 +16,12 @@ var ( AppDir = "/opt/Ollama" AppDataDir = "/opt/Ollama" // TODO - should there be a distinct log dir? 
- UpdateStageDir = "/tmp" - AppLogFile = "/tmp/ollama_app.log" - ServerLogFile = "/tmp/ollama.log" - UpgradeLogFile = "/tmp/ollama_update.log" - Installer = "OllamaSetup.exe" + UpdateStageDir = "/tmp" + AppLogFile = "/tmp/ollama_app.log" + ServerLogFile = "/tmp/ollama.log" + UpgradeLogFile = "/tmp/ollama_update.log" + Installer = "OllamaSetup.exe" + LogRotationCount = 5 ) func init() { diff --git a/app/lifecycle/server.go b/app/lifecycle/server.go index 0152ccd11..c178a1abf 100644 --- a/app/lifecycle/server.go +++ b/app/lifecycle/server.go @@ -54,7 +54,7 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) { return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err) } - // TODO - rotation + rotateLogs(ServerLogFile) logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755) if err != nil { return nil, fmt.Errorf("failed to create server log: %w", err) diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 60d63c7d9..de29b344c 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -22,7 +22,7 @@ docker logs If manually running `ollama serve` in a terminal, the logs will be on that terminal. When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `+R` and type in: -- `explorer %LOCALAPPDATA%\Ollama` to view logs +- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log` - `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH) - `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored - `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories diff --git a/docs/windows.md b/docs/windows.md index 832b3d431..abc0eb300 100644 --- a/docs/windows.md +++ b/docs/windows.md @@ -39,8 +39,8 @@ server. Ollama on Windows stores files in a few different locations. 
You can view them in the explorer window by hitting `+R` and type in: - `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates - - *app.log* contains logs from the GUI application - - *server.log* contains the server logs + - *app.log* contains most resent logs from the GUI application + - *server.log* contains the most recent server logs - *upgrade.log* contains log output for upgrades - `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH) - `explorer %HOMEPATH%\.ollama` contains models and configuration From fedf71635ec77644f8477a86c6155217d9213a11 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Wed, 19 Jun 2024 14:19:02 -0700 Subject: [PATCH 11/54] Extend api/show and ollama show to return more model info (#4881) * API Show Extended * Initial Draft of Information Co-Authored-By: Patrick Devine * Clean Up * Descriptive arg error messages and other fixes * Second Draft of Show with Projectors Included * Remove Chat Template * Touches * Prevent wrapping from files * Verbose functionality * Docs * Address Feedback * Lint * Resolve Conflicts * Function Name * Tests for api/show model info * Show Test File * Add Projector Test * Clean routes * Projector Check * Move Show Test * Touches * Doc update --------- Co-authored-by: Patrick Devine --- api/types.go | 19 +++--- cmd/cmd.go | 143 +++++++++++++++++++++++++++++++++++++----- docs/api.md | 37 +++++++++-- server/routes.go | 35 +++++++++++ server/routes_test.go | 39 ++++++++++++ 5 files changed, 243 insertions(+), 30 deletions(-) diff --git a/api/types.go b/api/types.go index 7822a6034..0a1189e70 100644 --- a/api/types.go +++ b/api/types.go @@ -253,6 +253,7 @@ type ShowRequest struct { Model string `json:"model"` System string `json:"system"` Template string `json:"template"` + Verbose bool `json:"verbose"` Options map[string]interface{} `json:"options"` @@ -262,14 +263,16 @@ type ShowRequest struct { // ShowResponse is the response returned from [Client.Show]. type ShowResponse struct { - License string `json:"license,omitempty"` - Modelfile string `json:"modelfile,omitempty"` - Parameters string `json:"parameters,omitempty"` - Template string `json:"template,omitempty"` - System string `json:"system,omitempty"` - Details ModelDetails `json:"details,omitempty"` - Messages []Message `json:"messages,omitempty"` - ModifiedAt time.Time `json:"modified_at,omitempty"` + License string `json:"license,omitempty"` + Modelfile string `json:"modelfile,omitempty"` + Parameters string `json:"parameters,omitempty"` + Template string `json:"template,omitempty"` + System string `json:"system,omitempty"` + Details ModelDetails `json:"details,omitempty"` + Messages []Message `json:"messages,omitempty"` + ModelInfo map[string]any `json:"model_info,omitempty"` + ProjectorInfo map[string]any `json:"projector_info,omitempty"` + ModifiedAt time.Time `json:"modified_at,omitempty"` } // CopyRequest is the request passed to [Client.Copy]. 
diff --git a/cmd/cmd.go b/cmd/cmd.go index ae7c8da8f..68197f72d 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -579,10 +579,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error { return err } - if len(args) != 1 { - return errors.New("missing model name") - } - license, errLicense := cmd.Flags().GetBool("license") modelfile, errModelfile := cmd.Flags().GetBool("modelfile") parameters, errParams := cmd.Flags().GetBool("parameters") @@ -625,8 +621,29 @@ func ShowHandler(cmd *cobra.Command, args []string) error { if flagsSet > 1 { return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified") - } else if flagsSet == 0 { - return errors.New("one of '--license', '--modelfile', '--parameters', '--system', or '--template' must be specified") + } + + if flagsSet == 1 { + req := api.ShowRequest{Name: args[0]} + resp, err := client.Show(cmd.Context(), &req) + if err != nil { + return err + } + + switch showType { + case "license": + fmt.Println(resp.License) + case "modelfile": + fmt.Println(resp.Modelfile) + case "parameters": + fmt.Println(resp.Parameters) + case "system": + fmt.Println(resp.System) + case "template": + fmt.Println(resp.Template) + } + + return nil } req := api.ShowRequest{Name: args[0]} @@ -635,22 +652,114 @@ func ShowHandler(cmd *cobra.Command, args []string) error { return err } - switch showType { - case "license": - fmt.Println(resp.License) - case "modelfile": - fmt.Println(resp.Modelfile) - case "parameters": - fmt.Println(resp.Parameters) - case "system": - fmt.Println(resp.System) - case "template": - fmt.Println(resp.Template) + arch := resp.ModelInfo["general.architecture"].(string) + + modelData := [][]string{ + {"arch", arch}, + {"parameters", resp.Details.ParameterSize}, + {"quantization", resp.Details.QuantizationLevel}, + {"context length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))}, + {"embedding length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64))}, } + mainTableData := [][]string{ + {"Model"}, + {renderSubTable(modelData, false)}, + } + + if resp.ProjectorInfo != nil { + projectorData := [][]string{ + {"arch", "clip"}, + {"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))}, + {"projector type", resp.ProjectorInfo["clip.projector_type"].(string)}, + {"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))}, + {"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))}, + } + + mainTableData = append(mainTableData, + []string{"Projector"}, + []string{renderSubTable(projectorData, false)}, + ) + } + + if resp.Parameters != "" { + mainTableData = append(mainTableData, []string{"Parameters"}, []string{formatParams(resp.Parameters)}) + } + + if resp.System != "" { + mainTableData = append(mainTableData, []string{"System"}, []string{renderSubTable(twoLines(resp.System), true)}) + } + + if resp.License != "" { + mainTableData = append(mainTableData, []string{"License"}, []string{renderSubTable(twoLines(resp.License), true)}) + } + + table := tablewriter.NewWriter(os.Stdout) + table.SetAutoWrapText(false) + table.SetBorder(false) + table.SetAlignment(tablewriter.ALIGN_LEFT) + + for _, v := range mainTableData { + table.Append(v) + } + + table.Render() + return nil } +func renderSubTable(data [][]string, file bool) string { + var buf bytes.Buffer + table := tablewriter.NewWriter(&buf) + 
table.SetAutoWrapText(!file) + table.SetBorder(false) + table.SetNoWhiteSpace(true) + table.SetTablePadding("\t") + table.SetAlignment(tablewriter.ALIGN_LEFT) + + for _, v := range data { + table.Append(v) + } + + table.Render() + + renderedTable := buf.String() + lines := strings.Split(renderedTable, "\n") + for i, line := range lines { + lines[i] = "\t" + line + } + + return strings.Join(lines, "\n") +} + +func twoLines(s string) [][]string { + lines := strings.Split(s, "\n") + res := [][]string{} + + count := 0 + for _, line := range lines { + line = strings.TrimSpace(line) + if line != "" { + count++ + res = append(res, []string{line}) + if count == 2 { + return res + } + } + } + return res +} + +func formatParams(s string) string { + lines := strings.Split(s, "\n") + table := [][]string{} + + for _, line := range lines { + table = append(table, strings.Fields(line)) + } + return renderSubTable(table, false) +} + func CopyHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { diff --git a/docs/api.md b/docs/api.md index 35f1def33..107b5211f 100644 --- a/docs/api.md +++ b/docs/api.md @@ -777,11 +777,12 @@ A single JSON object will be returned. POST /api/show ``` -Show information about a model including details, modelfile, template, parameters, license, and system prompt. +Show information about a model including details, modelfile, template, parameters, license, system prompt. ### Parameters - `name`: name of the model to show +- `verbose`: (optional) if set to `true`, returns full data for verbose response fields ### Examples @@ -798,14 +799,40 @@ curl http://localhost:11434/api/show -d '{ ```json { "modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"", - "parameters": "num_ctx 4096\nstop \u003c/s\u003e\nstop USER:\nstop ASSISTANT:", - "template": "{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: ", + "parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"", + "template": "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>", "details": { + "parent_model": "", "format": "gguf", "family": "llama", - "families": ["llama", "clip"], - "parameter_size": "7B", + "families": [ + "llama" + ], + "parameter_size": "8.0B", "quantization_level": "Q4_0" + }, + "model_info": { + "general.architecture": "llama", + "general.file_type": 2, + "general.parameter_count": 8030261248, + "general.quantization_version": 2, + "llama.attention.head_count": 32, + "llama.attention.head_count_kv": 8, + "llama.attention.layer_norm_rms_epsilon": 0.00001, + "llama.block_count": 32, + "llama.context_length": 8192, + "llama.embedding_length": 4096, + "llama.feed_forward_length": 14336, + "llama.rope.dimension_count": 128, + "llama.rope.freq_base": 500000, + "llama.vocab_size": 128256, + "tokenizer.ggml.bos_token_id": 128000, + "tokenizer.ggml.eos_token_id": 128009, + "tokenizer.ggml.merges": [], // populates if 
`verbose=true` + "tokenizer.ggml.model": "gpt2", + "tokenizer.ggml.pre": "llama-bpe", + "tokenizer.ggml.token_type": [], // populates if `verbose=true` + "tokenizer.ggml.tokens": [] // populates if `verbose=true` } } ``` diff --git a/server/routes.go b/server/routes.go index f36fe1b08..3d112e9f1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -734,9 +734,44 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { fmt.Fprint(&sb, m.String()) resp.Modelfile = sb.String() + kvData, err := getKVData(m.ModelPath, req.Verbose) + if err != nil { + return nil, err + } + delete(kvData, "general.name") + delete(kvData, "tokenizer.chat_template") + resp.ModelInfo = kvData + + if len(m.ProjectorPaths) > 0 { + projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose) + if err != nil { + return nil, err + } + resp.ProjectorInfo = projectorData + } + return resp, nil } +func getKVData(digest string, verbose bool) (llm.KV, error) { + kvData, err := llm.LoadModel(digest) + if err != nil { + return nil, err + } + + kv := kvData.KV() + + if !verbose { + for k := range kv { + if t, ok := kv[k].([]any); len(t) > 5 && ok { + kv[k] = []any{} + } + } + } + + return kv, nil +} + func (s *Server) ListModelsHandler(c *gin.Context) { ms, err := Manifests() if err != nil { diff --git a/server/routes_test.go b/server/routes_test.go index 5e16cfeff..5a5c0fbba 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -19,6 +19,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -212,6 +213,7 @@ func Test_Routes(t *testing.T) { "top_p 0.9", } assert.Equal(t, expectedParams, params) + assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0") }, }, } @@ -325,3 +327,40 @@ func TestCase(t *testing.T) { }) } } + +func TestShow(t *testing.T) { + t.Setenv("OLLAMA_MODELS", t.TempDir()) + envconfig.LoadConfig() + + var s Server + + createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "show-model", + Modelfile: fmt.Sprintf( + "FROM %s\nFROM %s", + createBinFile(t, llm.KV{"general.architecture": "test"}, nil), + createBinFile(t, llm.KV{"general.architecture": "clip"}, nil), + ), + }) + + w := createRequest(t, s.ShowModelHandler, api.ShowRequest{ + Name: "show-model", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + var resp api.ShowResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + if resp.ModelInfo["general.architecture"] != "test" { + t.Fatal("Expected model architecture to be 'test', but got", resp.ModelInfo["general.architecture"]) + } + + if resp.ProjectorInfo["general.architecture"] != "clip" { + t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"]) + } +} From 23e899f32d9f7b3bbe0b902a95c23be5a1254409 Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Thu, 20 Jun 2024 08:51:35 -0700 Subject: [PATCH 12/54] skip os.removeAll() if PID does not exist --- gpu/assets.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gpu/assets.go b/gpu/assets.go index f2adcf3e3..fdb3dd81d 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -87,6 +87,8 @@ func cleanupTmpDirs() { } } else { slog.Debug("failed to open ollama.pid", "path", d, "error", err) + // No pid, ignore this tmpdir + continue } err = os.RemoveAll(d) if err != nil 
{ From 4ebb66c6623d85f4fb69db0406ddd05bdc2d893d Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Thu, 20 Jun 2024 09:23:43 -0700 Subject: [PATCH 13/54] reformat error check --- gpu/assets.go | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/gpu/assets.go b/gpu/assets.go index fdb3dd81d..e2abfd58f 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -77,19 +77,20 @@ func cleanupTmpDirs() { continue } raw, err := os.ReadFile(filepath.Join(d, "ollama.pid")) - if err == nil { - pid, err := strconv.Atoi(string(raw)) - if err == nil { - if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { - // Another running ollama, ignore this tmpdir - continue - } - } - } else { - slog.Debug("failed to open ollama.pid", "path", d, "error", err) + if err != nil { + slog.Warn("failed to read ollama.pid", "path", d, "error", err) // No pid, ignore this tmpdir continue } + + pid, err := strconv.Atoi(string(raw)) + if err == nil { + if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { + // Another running ollama, ignore this tmpdir + continue + } + } + err = os.RemoveAll(d) if err != nil { slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err) From 662568d453debcf77d2e077ef98cfb2cfab8575e Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Thu, 20 Jun 2024 09:30:59 -0700 Subject: [PATCH 14/54] err!=nil check --- gpu/assets.go | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/gpu/assets.go b/gpu/assets.go index e2abfd58f..073d2e813 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -84,16 +84,20 @@ func cleanupTmpDirs() { } pid, err := strconv.Atoi(string(raw)) - if err == nil { - if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { - // Another running ollama, ignore this tmpdir - continue - } + if err != nil { + slog.Warn("failed to parse pid", "path", d, "error", err) + continue } - err = os.RemoveAll(d) - if err != nil { - slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err) + proc, err := os.FindProcess(pid) + if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { + slog.Warn("found running ollama", "pid", pid, "path", d) + // Another running ollama, ignore this tmpdir + continue + } + + if err := os.Remove(d); err != nil { + slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err) } } } From 8e0641a9bffd0dde96e94a34bb3a4929da66c772 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 20 Jun 2024 09:40:17 -0700 Subject: [PATCH 15/54] handle asymmetric embedding KVs --- llm/ggml.go | 40 +++++++++++++++++++++++++++++++++------- llm/memory.go | 4 ++-- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/llm/ggml.go b/llm/ggml.go index 4d9ba97a8..f02f0ff60 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -69,6 +69,30 @@ func (kv KV) HeadCountKV() uint64 { return 1 } +func (kv KV) EmbeddingHeadCount() uint64 { + if heads := kv.HeadCount(); heads > 0 { + return kv.EmbeddingLength() / kv.HeadCount() + } + + return 0 +} + +func (kv KV) EmbeddingHeadCountK() uint64 { + if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 { + return k + } + + return kv.EmbeddingHeadCount() +} + +func (kv KV) EmbeddingHeadCountV() uint64 { + if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 { + return v + } + + return kv.EmbeddingHeadCount() +} + func 
(kv KV) GQA() uint64 { return kv.HeadCount() / kv.HeadCountKV() } @@ -299,6 +323,9 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui headsKV := llm.KV().HeadCountKV() vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any))) + embeddingHeads := llm.KV().EmbeddingHeadCount() + embeddingHeadsK := llm.KV().EmbeddingHeadCountK() + layers := llm.Tensors().Layers() switch llm.KV().Architecture() { @@ -308,7 +335,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui partialOffload = 4 * batch * embedding partialOffload += max( // 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()), - 4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV), + 4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV), 4*batch*(embedding+vocab)+embedding*vocab*105/128, ) @@ -316,15 +343,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui // mixtral 8x22b ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32)) partialOffload = max( - 3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV), - 4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch), + 3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV), + 4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch), ) } else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok { // mixtral 8x7b ffnGateWeight1 := ffnGateWeight.Shape[1] fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1) partialOffload = max( - 4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16, + 4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16, 4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16), ) } @@ -368,15 +395,14 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui fullOffload, ) case "deepseek2": - keys := uint64(llm.KV()["deepseek2.attention.key_length"].(uint32)) fullOffload = max( 4*batch*(3*embedding+vocab), - 4*batch*(3*embedding+2+context*(1+headsKV)+2*keys*headsKV), + 4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV), ) partialOffload = max( 4*batch*(3*embedding+vocab)+embedding*vocab*105/128, - 4*batch*(2*embedding+1+2*keys*headsKV+context+context*headsKV)+4*keys*context*headsKV+embedding*keys*headsKV*9/16, + 4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16, ) } diff --git a/llm/memory.go b/llm/memory.go index b8b862bd6..19b12cbfc 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -115,8 +115,8 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts slog.Warn("model missing blk.0 layer size") } - // fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv - var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV() + // fp16 k,v = sizeof(float16) * n_ctx * n_layer * (n_embd_head_k + n_embd_head_v) * n_head_kv + var kv uint64 = 2 * 
uint64(opts.NumCtx) * ggml.KV().BlockCount() * (ggml.KV().EmbeddingHeadCountK() + ggml.KV().EmbeddingHeadCountV()) * ggml.KV().HeadCountKV() // KV is proportional to the number of layers layerSize += kv / ggml.KV().BlockCount() From 5bf5aeec0140a70eeb94b65c61dbb3b75ff33e56 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Thu, 20 Jun 2024 11:07:04 -0700 Subject: [PATCH 16/54] Refine mmap default logic on linux If we try to use mmap when the model is larger than the system free space, loading is slower than the no-mmap approach. --- llm/server.go | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/llm/server.go b/llm/server.go index ed5f288f2..da83416ee 100644 --- a/llm/server.go +++ b/llm/server.go @@ -81,7 +81,17 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr var err error var cpuRunner string var estimate MemoryEstimate - var systemMemory uint64 + var systemTotalMemory uint64 + var systemFreeMemory uint64 + + systemMemInfo, err := gpu.GetCPUMem() + if err != nil { + slog.Error("failed to lookup system memory", "error", err) + } else { + systemTotalMemory = systemMemInfo.TotalMemory + systemFreeMemory = systemMemInfo.FreeMemory + slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory) + } // If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info if opts.NumGPU == 0 { @@ -91,19 +101,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr cpuRunner = serverForCpu() estimate = EstimateGPULayers(gpus, ggml, projectors, opts) } else { - if gpus[0].Library == "metal" { - memInfo, err := gpu.GetCPUMem() - if err != nil { - slog.Error("failed to lookup system memory", "error", err) - } else { - systemMemory = memInfo.TotalMemory - slog.Debug("system memory", "total", format.HumanBytes2(systemMemory)) - } - } estimate = EstimateGPULayers(gpus, ggml, projectors, opts) switch { - case gpus[0].Library == "metal" && estimate.VRAMSize > systemMemory: + case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory: // disable partial offloading when model is greater than total system memory as this // can lead to locking up the system opts.NumGPU = 0 @@ -211,7 +212,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr } // Windows CUDA should not use mmap for best performance - if (runtime.GOOS == "windows" && gpus[0].Library == "cuda") || opts.UseMMap == api.TriStateFalse { + // Linux with a model larger than free space, mmap leads to thrashing + if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == api.TriStateUndefined) || + (runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == api.TriStateUndefined) || + opts.UseMMap == api.TriStateFalse { params = append(params, "--no-mmap") } From 7e7749224c57ea4d7ae98e4d07dcb00e192a5c7c Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 21 Jun 2024 12:27:19 -0700 Subject: [PATCH 17/54] Fix use_mmap parsing for modelfiles Add the new tristate parsing logic for the code path for modelfiles, as well as a unit test. 
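
A minimal sketch of the tristate mapping this patch adds for Modelfile parameters: the `use_mmap` value is run through strconv.ParseBool and converted to a three-valued setting rather than a plain bool. The local TriState type and parseUseMmap helper are illustrative; the real definitions live in the api package.

```go
package main

import (
	"fmt"
	"strconv"
)

// TriState stands in for api.TriState: use_mmap may be left undefined,
// forced on, or forced off. These local constants are a sketch; the real
// definitions live in the api package.
type TriState int

const (
	TriStateUndefined TriState = iota
	TriStateFalse
	TriStateTrue
)

// parseUseMmap sketches the branch this patch adds to FormatParams: a
// Modelfile `PARAMETER use_mmap ...` value goes through strconv.ParseBool
// and maps onto the tri-state instead of a plain bool.
func parseUseMmap(val string) (TriState, error) {
	b, err := strconv.ParseBool(val)
	if err != nil {
		return TriStateUndefined, fmt.Errorf("invalid bool value [%s]", val)
	}
	if b {
		return TriStateTrue, nil
	}
	return TriStateFalse, nil
}

func main() {
	for _, v := range []string{"true", "0", "foo"} {
		ts, err := parseUseMmap(v)
		fmt.Println(v, "->", ts, err)
	}
}
```
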
--- api/types.go | 13 ++++++++++ api/types_test.go | 63 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/api/types.go b/api/types.go index 0a1189e70..95ed5d37e 100644 --- a/api/types.go +++ b/api/types.go @@ -608,6 +608,19 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) { } else { field := valueOpts.FieldByName(opt.Name) if field.IsValid() && field.CanSet() { + if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) { + boolVal, err := strconv.ParseBool(vals[0]) + if err != nil { + return nil, fmt.Errorf("invalid bool value %s", vals) + } + if boolVal { + out[key] = TriStateTrue + } else { + out[key] = TriStateFalse + } + continue + } + switch field.Kind() { case reflect.Float32: floatVal, err := strconv.ParseFloat(vals[0], 32) diff --git a/api/types_test.go b/api/types_test.go index 7b4a0f83c..8b6c60c62 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -2,6 +2,7 @@ package api import ( "encoding/json" + "fmt" "math" "testing" "time" @@ -141,3 +142,65 @@ func TestUseMmapParsingFromJSON(t *testing.T) { }) } } + +func TestUseMmapFormatParams(t *testing.T) { + tests := []struct { + name string + req map[string][]string + exp TriState + err error + }{ + { + name: "True", + req: map[string][]string{ + "use_mmap": []string{"true"}, + }, + exp: TriStateTrue, + err: nil, + }, + { + name: "False", + req: map[string][]string{ + "use_mmap": []string{"false"}, + }, + exp: TriStateFalse, + err: nil, + }, + { + name: "Numeric True", + req: map[string][]string{ + "use_mmap": []string{"1"}, + }, + exp: TriStateTrue, + err: nil, + }, + { + name: "Numeric False", + req: map[string][]string{ + "use_mmap": []string{"0"}, + }, + exp: TriStateFalse, + err: nil, + }, + { + name: "invalid string", + req: map[string][]string{ + "use_mmap": []string{"foo"}, + }, + exp: TriStateUndefined, + err: fmt.Errorf("invalid bool value [foo]"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + resp, err := FormatParams(test.req) + require.Equal(t, err, test.err) + respVal, ok := resp["use_mmap"] + if test.exp != TriStateUndefined { + assert.True(t, ok, "resp: %v", resp) + assert.Equal(t, test.exp, respVal) + } + }) + } +} From e835ef183691db1cc7da30cfc61fb4b96b321e80 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 21 Jun 2024 13:30:43 -0700 Subject: [PATCH 18/54] fix: quantization with template --- server/images.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/server/images.go b/server/images.go index 53a957715..98794149e 100644 --- a/server/images.go +++ b/server/images.go @@ -414,17 +414,22 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio return err } - layers, err := parseFromFile(ctx, temp, "", fn) + layer, err := NewLayer(temp, baseLayer.MediaType) if err != nil { return err } - if len(layers) != 1 { - return errors.New("quantization failed") + if _, err := temp.Seek(0, io.SeekStart); err != nil { + return err } - baseLayer.Layer = layers[0].Layer - baseLayer.GGML = layers[0].GGML + ggml, _, err := llm.DecodeGGML(temp) + if err != nil { + return err + } + + baseLayer.Layer = layer + baseLayer.GGML = ggml } } From 17b7186cd759337fa98b626e82de150f3789b040 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 6 May 2024 17:47:52 -0700 Subject: [PATCH 19/54] Enable concurrency by default This adjusts our default settings to enable multiple models and parallel requests to a single model. 
Users can still override these by the same env var settings as before. Parallel has a direct impact on num_ctx, which in turn can have a significant impact on small VRAM GPUs so this change also refines the algorithm so that when parallel is not explicitly set by the user, we try to find a reasonable default that fits the model on their GPU(s). As before, multiple models will only load concurrently if they fully fit in VRAM. --- envconfig/config.go | 16 ++++---- llm/server.go | 13 ++---- server/sched.go | 98 +++++++++++++++++++++++++++++++++----------- server/sched_test.go | 80 +++++++++++++++++++++++------------- 4 files changed, 135 insertions(+), 72 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index e86f72e6a..cb456448c 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -85,13 +85,13 @@ func AsMap() map[string]EnvVar { "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, - "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"}, + "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU (default 4)"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, - "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"}, + "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, @@ -129,8 +129,8 @@ func clean(key string) string { func init() { // default values - NumParallel = 1 - MaxRunners = 1 + NumParallel = 0 + MaxRunners = 4 MaxQueuedRequests = 512 LoadConfig() @@ -205,8 +205,8 @@ func LoadConfig() { if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" { val, err := strconv.Atoi(onp) - if err != nil || val <= 0 { - slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err) + if err != nil { + slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err) } else { NumParallel = val } @@ -251,7 +251,7 @@ func LoadConfig() { if maxRunners != "" { m, err := strconv.Atoi(maxRunners) if err != nil { - slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err) + slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err) } else { MaxRunners = m } @@ -260,7 +260,7 @@ func LoadConfig() { if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" { p, err := strconv.Atoi(onp) if err != nil || p <= 0 { - slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err) + slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err) } 
else { MaxQueuedRequests = p } diff --git a/llm/server.go b/llm/server.go index da83416ee..3cb5ac1f0 100644 --- a/llm/server.go +++ b/llm/server.go @@ -77,7 +77,7 @@ func LoadModel(model string) (*GGML, error) { // NewLlamaServer will run a server for the given GPUs // The gpu list must be a single family. -func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) { +func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) { var err error var cpuRunner string var estimate MemoryEstimate @@ -213,8 +213,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr // Windows CUDA should not use mmap for best performance // Linux with a model larger than free space, mmap leads to thrashing + // For CPU loads we want the memory to be allocated, not FS cache if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == api.TriStateUndefined) || (runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == api.TriStateUndefined) || + (gpus[0].Library == "cpu" && opts.UseMMap == api.TriStateUndefined) || opts.UseMMap == api.TriStateFalse { params = append(params, "--no-mmap") } @@ -227,15 +229,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--numa") } - numParallel := envconfig.NumParallel - - // TODO (jmorganca): multimodal models don't support parallel yet - // see https://github.com/ollama/ollama/issues/4165 - if len(projectors) > 0 { - numParallel = 1 - slog.Warn("multimodal models don't support parallel requests yet") - } - params = append(params, "--parallel", fmt.Sprintf("%d", numParallel)) if estimate.TensorSplit != "" { diff --git a/server/sched.go b/server/sched.go index 424395544..31ef560f5 100644 --- a/server/sched.go +++ b/server/sched.go @@ -23,6 +23,7 @@ type LlmRequest struct { ctx context.Context //nolint:containedctx model *Model opts api.Options + origNumCTX int // Track the initial ctx request sessionDuration time.Duration successCh chan *runnerRef errCh chan error @@ -38,8 +39,8 @@ type Scheduler struct { loaded map[string]*runnerRef loadedMu sync.Mutex - loadFn func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) - newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) + loadFn func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) + newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) getGpuFn func() gpu.GpuInfoList getCpuFn func() gpu.GpuInfoList reschedDelay time.Duration @@ -65,13 +66,10 @@ func InitScheduler(ctx context.Context) *Scheduler { // context must be canceled to decrement ref count and release the runner func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) { - // allocate a large enough kv cache for all parallel requests if opts.NumCtx < 4 { opts.NumCtx = 4 } - opts.NumCtx *= envconfig.NumParallel - req := &LlmRequest{ ctx: c, model: model, @@ -102,6 +100,7 @@ func (s *Scheduler) Run(ctx context.Context) { } func (s *Scheduler) processPending(ctx context.Context) { + maxRunnerFactor := 1 // number of GPUs or 1 for { select { case 
<-ctx.Done(): @@ -110,11 +109,25 @@ func (s *Scheduler) processPending(ctx context.Context) { case pending := <-s.pendingReqCh: // Block other requests until we get this pending request running pending.schedAttempts++ + if pending.origNumCTX == 0 { + pending.origNumCTX = pending.opts.NumCtx + } if pending.ctx.Err() != nil { slog.Debug("pending request cancelled or timed out, skipping scheduling") continue } + numParallel := envconfig.NumParallel + // TODO (jmorganca): multimodal models don't support parallel yet + // see https://github.com/ollama/ollama/issues/4165 + if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 { + numParallel = 1 + slog.Warn("multimodal models don't support parallel requests yet") + } + // Keep NumCtx and numParallel in sync + if numParallel > 1 { + pending.opts.NumCtx = pending.origNumCTX * numParallel + } for { var runnerToExpire *runnerRef @@ -130,7 +143,7 @@ func (s *Scheduler) processPending(ctx context.Context) { pending.useLoadedRunner(runner, s.finishedReqCh) break } - } else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners { + } else if envconfig.MaxRunners > 0 && loadedCount >= (maxRunnerFactor*envconfig.MaxRunners) { slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) runnerToExpire = s.findRunnerToUnload() } else { @@ -142,6 +155,7 @@ func (s *Scheduler) processPending(ctx context.Context) { } else { gpus = s.getGpuFn() } + maxRunnerFactor = max(len(gpus), 1) // Load model for fitting ggml, err := llm.LoadModel(pending.model.ModelPath) @@ -152,26 +166,32 @@ func (s *Scheduler) processPending(ctx context.Context) { // Evaluate if the model will fit in the available system memory, or if we should unload a model first if len(gpus) == 1 && gpus[0].Library == "cpu" { + // simplifying assumption of defaultParallel when in CPU mode + if numParallel <= 0 { + numParallel = defaultParallel + pending.opts.NumCtx = pending.origNumCTX * numParallel + } + if loadedCount == 0 { slog.Debug("cpu mode with first model, loading") - s.loadFn(pending, ggml, gpus) + s.loadFn(pending, ggml, gpus, numParallel) break } runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus) if runnerToExpire == nil { slog.Debug("cpu mode with available system memory or first model, loading") - s.loadFn(pending, ggml, gpus) + s.loadFn(pending, ggml, gpus, numParallel) break } // else we need to expire a runner } else if loadedCount == 0 { // No models loaded. Load the model but prefer the best fit. 
slog.Debug("loading first model", "model", pending.model.ModelPath) - g := pickBestFitGPUs(pending, ggml, gpus) + g := pickBestFitGPUs(pending, ggml, gpus, &numParallel) if g != nil { gpus = g } - s.loadFn(pending, ggml, gpus) + s.loadFn(pending, ggml, gpus, numParallel) break } @@ -186,10 +206,10 @@ func (s *Scheduler) processPending(ctx context.Context) { // Update free memory from currently loaded models s.updateFreeSpace(availGpus) - fitGpus := pickBestFitGPUs(pending, ggml, availGpus) + fitGpus := pickBestFitGPUs(pending, ggml, availGpus, &numParallel) if fitGpus != nil { slog.Debug("new model fits with existing models, loading") - s.loadFn(pending, ggml, fitGpus) + s.loadFn(pending, ggml, fitGpus, numParallel) break } @@ -350,8 +370,11 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm }() } -func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) { - llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts) +func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { + if numParallel < 1 { + numParallel = 1 + } + llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel) if err != nil { // some older models are not compatible with newer versions of llama.cpp // show a generalized compatibility error until there is a better way to @@ -375,6 +398,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) loading: true, refCount: 1, } + runner.numParallel = numParallel runner.refMu.Lock() s.loadedMu.Lock() @@ -483,8 +507,9 @@ type runnerRef struct { expireTimer *time.Timer expiresAt time.Time - model *Model - modelPath string + model *Model + modelPath string + numParallel int *api.Options } @@ -525,6 +550,9 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool optsNew.NumGPU = -1 } + // Normalize the NumCtx for parallelism + optsExisting.NumCtx = optsExisting.NumCtx / runner.numParallel + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed? @@ -611,22 +639,38 @@ func (a ByDuration) Less(i, j int) bool { // pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits // If the model can not be fit fully within the available GPU(s) nil is returned -func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.GpuInfoList { +// If numParallel is <= 0, this will attempt try to optimize parallism based on available VRAM, and adjust +// opts.NumCtx accordingly +func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel *int) gpu.GpuInfoList { var estimatedVRAM uint64 + + var numParallelToTry []int + if *numParallel <= 0 { + // If no specific parallel setting was provided, try larger then smaller, always end with 1 + numParallelToTry = append(numParallelToTry, 4, 1) + } else { + numParallelToTry = []int{*numParallel} + } + for _, gl := range gpus.ByLibrary() { var ok bool sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...) // TODO - potentially sort by performance capability, existing models loaded, etc. 
+ // TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them // Note: at present, this will favor more VRAM over faster GPU speed in mixed setups sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl))) // First attempt to fit the model into a single GPU - if !envconfig.SchedSpread { - for _, g := range sgl { - if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { - slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM)) - return []gpu.GpuInfo{g} + for _, p := range numParallelToTry { + req.opts.NumCtx = req.origNumCTX * p + if !envconfig.SchedSpread { + for _, g := range sgl { + if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { + slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM)) + *numParallel = p + return []gpu.GpuInfo{g} + } } } } @@ -636,9 +680,13 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu. // - try subsets of GPUs instead of just falling back to 1 or all in a family // Now try all the GPUs - if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { - slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "required", format.HumanBytes2(estimatedVRAM)) - return sgl + for _, p := range numParallelToTry { + req.opts.NumCtx = req.origNumCTX * p + if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { + slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM)) + *numParallel = p + return sgl + } } } return nil diff --git a/server/sched_test.go b/server/sched_test.go index 953288347..5e5913a7c 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -47,11 +47,11 @@ func TestLoad(t *testing.T) { sessionDuration: 2, } // Fail to load model first - s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { + s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { return nil, fmt.Errorf("something failed to load model blah") } gpus := gpu.GpuInfoList{} - s.load(req, ggml, gpus) + s.load(req, ggml, gpus, 0) require.Empty(t, req.successCh) require.Len(t, req.errCh, 1) s.loadedMu.Lock() @@ -61,10 +61,10 @@ func TestLoad(t *testing.T) { require.Contains(t, err.Error(), "this model may be incompatible") server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}} - s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { + s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { return server, nil } - s.load(req, ggml, gpus) + s.load(req, 
ggml, gpus, 0) select { case err := <-req.errCh: require.NoError(t, err) @@ -78,12 +78,12 @@ func TestLoad(t *testing.T) { req.model.ModelPath = "dummy_model_path" server.waitResp = fmt.Errorf("wait failure") - s.load(req, ggml, gpus) + s.load(req, ggml, gpus, 0) select { case err := <-req.errCh: require.Contains(t, err.Error(), "wait failure") case resp := <-req.successCh: - t.Errorf("unexpected success %v", resp) + t.Fatalf("unexpected success %v", resp) } s.loadedMu.Lock() runner := s.loaded["dummy_model_path"] @@ -102,7 +102,7 @@ type bundle struct { ggml *llm.GGML } -func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { +func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { return scenario.srv, nil } @@ -200,7 +200,7 @@ func TestRequests(t *testing.T) { require.Empty(t, s.pendingReqCh) require.Empty(t, scenario1a.req.errCh) case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } // Same runner as first request due to not needing a reload @@ -213,7 +213,7 @@ func TestRequests(t *testing.T) { require.Empty(t, s.pendingReqCh) require.Empty(t, scenario1b.req.errCh) case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } // Trigger a reload @@ -231,7 +231,7 @@ func TestRequests(t *testing.T) { require.Empty(t, s.pendingReqCh) require.Empty(t, scenario2a.req.errCh) case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } envconfig.MaxRunners = 1 @@ -247,7 +247,7 @@ func TestRequests(t *testing.T) { require.Empty(t, s.pendingReqCh) require.Empty(t, scenario3a.req.errCh) case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } s.loadedMu.Lock() require.Len(t, s.loaded, 1) @@ -263,7 +263,7 @@ func TestRequests(t *testing.T) { require.Empty(t, s.pendingReqCh) require.Empty(t, scenario3b.req.errCh) case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } s.loadedMu.Lock() require.Len(t, s.loaded, 2) @@ -279,7 +279,7 @@ func TestRequests(t *testing.T) { require.Empty(t, s.pendingReqCh) require.Empty(t, scenario3c.req.errCh) case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } s.loadedMu.Lock() require.Len(t, s.loaded, 3) @@ -306,7 +306,7 @@ func TestRequests(t *testing.T) { require.Empty(t, s.pendingReqCh) require.Empty(t, scenario3d.req.errCh) case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } s.loadedMu.Lock() require.Len(t, s.loaded, 2) @@ -349,7 +349,7 @@ func TestGetRunner(t *testing.T) { require.Empty(t, s.pendingReqCh) require.Empty(t, errCh1a) case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } scenario1a.ctxDone() s.loadedMu.Lock() @@ -400,7 +400,7 @@ func TestPrematureExpired(t *testing.T) { slog.Info("sending premature expired event now") s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } time.Sleep(scenario1a.req.sessionDuration) scenario1a.ctxDone() @@ -427,7 +427,7 @@ func TestUseLoadedRunner(t *testing.T) { } finished := make(chan *LlmRequest) llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} - r1 := &runnerRef{llama: llm1, sessionDuration: 1} + r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1} req.useLoadedRunner(r1, finished) require.Equal(t, uint(1), r1.refCount) require.Equal(t, time.Duration(2), r1.sessionDuration) @@ -435,7 +435,7 @@ func 
TestUseLoadedRunner(t *testing.T) { case success := <-req.successCh: require.Equal(t, r1, success) case <-ctx.Done(): - t.Errorf("timeout") + t.Fatal("timeout") } done() fin := <-finished @@ -461,8 +461,8 @@ func TestUpdateFreeSpace(t *testing.T) { gpus[1].FreeMemory = 1900 llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}} llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}} - r1 := &runnerRef{llama: llm1, gpus: gpus} - r2 := &runnerRef{llama: llm2, gpus: gpus} + r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1} + r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1} s := InitScheduler(ctx) s.loadedMu.Lock() @@ -513,8 +513,8 @@ func TestFindRunnerToUnload(t *testing.T) { ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) defer done() - r1 := &runnerRef{refCount: 1, sessionDuration: 1} - r2 := &runnerRef{sessionDuration: 2} + r1 := &runnerRef{refCount: 1, sessionDuration: 1, numParallel: 1} + r2 := &runnerRef{sessionDuration: 2, numParallel: 1} s := InitScheduler(ctx) s.loadedMu.Lock() @@ -536,9 +536,13 @@ func TestNeedsReload(t *testing.T) { llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} do := api.DefaultOptions() runner := &runnerRef{ - model: &Model{AdapterPaths: []string{"adapter1"}, ProjectorPaths: []string{"projector1"}}, - Options: &do, - llama: llm, + model: &Model{ + AdapterPaths: []string{"adapter1"}, + ProjectorPaths: []string{"projector1"}, + }, + Options: &do, + llama: llm, + numParallel: 1, } req := &LlmRequest{ model: &Model{ @@ -581,8 +585,8 @@ func TestUnloadAllRunners(t *testing.T) { s := InitScheduler(ctx) s.unloadAllRunners() - r1 := &runnerRef{llama: llm1} - r2 := &runnerRef{llama: llm2} + r1 := &runnerRef{llama: llm1, numParallel: 1} + r2 := &runnerRef{llama: llm2, numParallel: 1} s.loadedMu.Lock() s.loaded["a"] = r1 @@ -596,14 +600,32 @@ func TestUnloadAllRunners(t *testing.T) { func TestUnload(t *testing.T) { llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} - r1 := &runnerRef{llama: llm1} - r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}} + r1 := &runnerRef{llama: llm1, numParallel: 1} + r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1} r1.unload() require.True(t, llm1.closeCalled) r2.unload() require.Nil(t, r2.model) } +func TestAlreadyCanceled(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer done() + dctx, done2 := context.WithCancel(ctx) + done2() + scenario1a := newScenario(t, dctx, "ollama-model-1", 10) + scenario1a.req.sessionDuration = 0 + s := InitScheduler(ctx) + slog.Info("scenario1a") + s.pendingReqCh <- scenario1a.req + require.Len(t, s.pendingReqCh, 1) + s.Run(ctx) + time.Sleep(5 * time.Millisecond) + require.Empty(t, s.pendingReqCh) + require.Empty(t, scenario1a.req.errCh) + require.Empty(t, scenario1a.req.successCh) +} + type mockLlm struct { pingResp error waitResp error From 9929751cc8b415e7b83d5151742dad734e8b5efc Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 19 Jun 2024 13:35:38 -0700 Subject: [PATCH 20/54] Disable concurrency for AMD + Windows Until ROCm v6.2 ships, we wont be able to get accurate free memory reporting on windows, which makes automatic concurrency too risky. Users can still opt-in but will need to pay attention to model sizes otherwise they may thrash/page VRAM or cause OOM crashes. All other platforms and GPUs have accurate VRAM reporting wired up now, so we can turn on concurrency by default. 
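A condensed sketch of the auto-selection rule this patch implements: if every GPU reports trustworthy free memory, allow several models per GPU; if any GPU is flagged unreliable, fall back to one model per GPU. The GpuInfo struct below is trimmed to the single field the decision needs, and defaultModelsPerGPU mirrors the value introduced in the scheduler diff.

    // Sketch of the default-concurrency decision. Only the UnreliableFreeMemory
    // flag from the real gpu.GpuInfo type is modeled here.
    package main

    import "fmt"

    type GpuInfo struct {
        UnreliableFreeMemory bool
    }

    const defaultModelsPerGPU = 3

    // defaultMaxRunners returns the automatic OLLAMA_MAX_LOADED_MODELS value
    // used when the user has not set one explicitly.
    func defaultMaxRunners(gpus []GpuInfo) int {
        for _, g := range gpus {
            if g.UnreliableFreeMemory {
                // Free memory can't be trusted (e.g. ROCm on Windows before
                // v6.2), so don't stack models: one per GPU.
                return len(gpus)
            }
        }
        return defaultModelsPerGPU * len(gpus)
    }

    func main() {
        reliable := []GpuInfo{{false}, {false}}
        mixed := []GpuInfo{{false}, {true}}
        fmt.Println(defaultMaxRunners(reliable)) // 6
        fmt.Println(defaultMaxRunners(mixed))    // 2
    }
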
--- envconfig/config.go | 8 ++++---- gpu/amd_windows.go | 5 +++-- gpu/types.go | 5 +++++ server/sched.go | 36 ++++++++++++++++++++++++++++++++---- 4 files changed, 44 insertions(+), 10 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index cb456448c..0f0f7f058 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -85,13 +85,13 @@ func AsMap() map[string]EnvVar { "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, - "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU (default 4)"}, + "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU (default auto)"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, - "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, + "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default auto)"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, @@ -129,8 +129,8 @@ func clean(key string) string { func init() { // default values - NumParallel = 0 - MaxRunners = 4 + NumParallel = 0 // Autoselect + MaxRunners = 0 // Autoselect MaxQueuedRequests = 512 LoadConfig() diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go index 21585277a..8b6fabebb 100644 --- a/gpu/amd_windows.go +++ b/gpu/amd_windows.go @@ -115,8 +115,6 @@ func AMDGetGPUInfo() []RocmGPUInfo { continue } - // TODO revisit this once ROCm v6 is available on windows. - // v5.7 only reports VRAM used by this process, so it's completely wrong and unusable slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory)) slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory)) gpuInfo := RocmGPUInfo{ @@ -126,6 +124,9 @@ func AMDGetGPUInfo() []RocmGPUInfo { TotalMemory: totalMemory, FreeMemory: freeMemory, }, + // Free memory reporting on Windows is not reliable until we bump to ROCm v6.2 + UnreliableFreeMemory: true, + ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices DependencyPath: libDir, MinimumMemory: rocmMinimumMemory, diff --git a/gpu/types.go b/gpu/types.go index 9920db5ff..2eaa9bae9 100644 --- a/gpu/types.go +++ b/gpu/types.go @@ -29,6 +29,11 @@ type GpuInfo struct { // Extra environment variables specific to the GPU as list of [key,value] EnvWorkarounds [][2]string `json:"envs,omitempty"` + // Set to true if we can NOT reliably discover FreeMemory. 
A value of true indicates + // the FreeMemory is best effort, and may over or under report actual memory usage + // False indicates FreeMemory can generally be trusted on this GPU + UnreliableFreeMemory bool + // GPU information ID string `json:"gpu_id"` // string to use for selection of this specific GPU Name string `json:"name"` // user friendly name if available diff --git a/server/sched.go b/server/sched.go index 31ef560f5..de8c9d281 100644 --- a/server/sched.go +++ b/server/sched.go @@ -46,6 +46,16 @@ type Scheduler struct { reschedDelay time.Duration } +// Default automatic value for number of models we allow per GPU +// Model will still need to fit in VRAM, but loading many small models +// on a large GPU can cause stalling +var defaultModelsPerGPU = 3 + +// Default automatic value for parallel setting +// Model will still need to fit in VRAM. If this setting wont fit +// we'll back off down to 1 to try to get it to fit +var defaultParallel = 4 + var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded") func InitScheduler(ctx context.Context) *Scheduler { @@ -100,7 +110,6 @@ func (s *Scheduler) Run(ctx context.Context) { } func (s *Scheduler) processPending(ctx context.Context) { - maxRunnerFactor := 1 // number of GPUs or 1 for { select { case <-ctx.Done(): @@ -143,7 +152,7 @@ func (s *Scheduler) processPending(ctx context.Context) { pending.useLoadedRunner(runner, s.finishedReqCh) break } - } else if envconfig.MaxRunners > 0 && loadedCount >= (maxRunnerFactor*envconfig.MaxRunners) { + } else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners { slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) runnerToExpire = s.findRunnerToUnload() } else { @@ -155,7 +164,26 @@ func (s *Scheduler) processPending(ctx context.Context) { } else { gpus = s.getGpuFn() } - maxRunnerFactor = max(len(gpus), 1) + + if envconfig.MaxRunners <= 0 { + // No user specified MaxRunners, so figure out what automatic setting to use + // If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs + // if any GPU has unreliable free memory reporting, 1x the number of GPUs + allReliable := true + for _, gpu := range gpus { + if gpu.UnreliableFreeMemory { + allReliable = false + break + } + } + if allReliable { + envconfig.MaxRunners = defaultModelsPerGPU * len(gpus) + slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus)) + } else { + slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency") + envconfig.MaxRunners = len(gpus) + } + } // Load model for fitting ggml, err := llm.LoadModel(pending.model.ModelPath) @@ -647,7 +675,7 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numP var numParallelToTry []int if *numParallel <= 0 { // If no specific parallel setting was provided, try larger then smaller, always end with 1 - numParallelToTry = append(numParallelToTry, 4, 1) + numParallelToTry = append(numParallelToTry, defaultParallel, 1) } else { numParallelToTry = []int{*numParallel} } From 9a9e7d83c416374782a984d7036f3f2ae5ddb78d Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Fri, 21 Jun 2024 15:52:09 -0700 Subject: [PATCH 21/54] Docs (#5149) --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 2fdc63cb3..978625731 100644 --- a/README.md +++ b/README.md @@ 
-182,6 +182,12 @@ $ ollama run llama3 "Summarize this file: $(cat README.md)" Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications. ``` +### Show model information + +``` +ollama show llama3 +``` + ### List models on your computer ``` From 642cee13426c994f90d5a95376025fe9a223891a Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 21 Jun 2024 15:59:41 -0700 Subject: [PATCH 22/54] Sort the ps output Provide consistent ordering for the ps command - longest duration listed first --- server/routes.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server/routes.go b/server/routes.go index 3d112e9f1..a7f72edc2 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1224,6 +1224,11 @@ func (s *Server) ProcessHandler(c *gin.Context) { models = append(models, mr) } + slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int { + // longest duration remaining listed first + return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix()) + }) + c.JSON(http.StatusOK, api.ProcessResponse{Models: models}) } From 2aa91a937ba199ae5832c71ecc10221cc6420fa8 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 24 Jun 2024 20:14:03 -0700 Subject: [PATCH 23/54] cmd: defer stating model info until necessary (#5248) This commit changes the 'ollama run' command to defer fetching model information until it really needs it. That is, when in interactive mode. It also removes one such case where the model information is fetch in duplicate, just before calling generateInteractive and then again, first thing, in generateInteractive. This positively impacts the performance of the command: ; time ./before run llama3 'hi' Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat? ./before run llama3 'hi' 0.02s user 0.01s system 2% cpu 1.168 total ; time ./before run llama3 'hi' Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat? ./before run llama3 'hi' 0.02s user 0.01s system 2% cpu 1.220 total ; time ./before run llama3 'hi' Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat? ./before run llama3 'hi' 0.02s user 0.01s system 2% cpu 1.217 total ; time ./after run llama3 'hi' Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat? ./after run llama3 'hi' 0.02s user 0.01s system 4% cpu 0.652 total ; time ./after run llama3 'hi' Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat? ./after run llama3 'hi' 0.01s user 0.01s system 5% cpu 0.498 total ; time ./after run llama3 'hi' Hi! It's nice to meet you. Is there something I can help you with or would you like to chat? ./after run llama3 'hi' 0.01s user 0.01s system 3% cpu 0.479 total ; time ./after run llama3 'hi' Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat? ./after run llama3 'hi' 0.02s user 0.01s system 5% cpu 0.507 total ; time ./after run llama3 'hi' Hi! It's nice to meet you. Is there something I can help you with, or would you like to chat? 
./after run llama3 'hi' 0.02s user 0.01s system 5% cpu 0.507 total --- cmd/cmd.go | 65 +++++++++++++++++++++++----------------------- cmd/interactive.go | 51 ++++++++++-------------------------- 2 files changed, 46 insertions(+), 70 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 68197f72d..89b551f40 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -287,38 +287,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er } func RunHandler(cmd *cobra.Command, args []string) error { - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - - name := args[0] - - // check if the model exists on the server - show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name}) - var statusError api.StatusError - switch { - case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound: - if err := PullHandler(cmd, []string{name}); err != nil { - return err - } - - show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name}) - if err != nil { - return err - } - case err != nil: - return err - } - interactive := true opts := runOptions{ - Model: args[0], - WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]interface{}{}, - MultiModal: slices.Contains(show.Details.Families, "clip"), - ParentModel: show.Details.ParentModel, + Model: args[0], + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]interface{}{}, } format, err := cmd.Flags().GetString("format") @@ -362,11 +336,38 @@ func RunHandler(cmd *cobra.Command, args []string) error { } opts.WordWrap = !nowrap - if !interactive { - return generate(cmd, opts) + // Fill out the rest of the options based on information about the + // model. + client, err := api.ClientFromEnvironment() + if err != nil { + return err } - return generateInteractive(cmd, opts) + name := args[0] + info, err := func() (*api.ShowResponse, error) { + showReq := &api.ShowRequest{Name: name} + info, err := client.Show(cmd.Context(), showReq) + var se api.StatusError + if errors.As(err, &se) && se.StatusCode == http.StatusNotFound { + if err := PullHandler(cmd, []string{name}); err != nil { + return nil, err + } + return client.Show(cmd.Context(), &api.ShowRequest{Name: name}) + } + return info, err + }() + if err != nil { + return err + } + + opts.MultiModal = slices.Contains(info.Details.Families, "clip") + opts.ParentModel = info.Details.ParentModel + opts.Messages = append(opts.Messages, info.Messages...) + + if interactive { + return generateInteractive(cmd, opts) + } + return generate(cmd, opts) } func errFromUnknownKey(unknownKeyErr error) error { diff --git a/cmd/interactive.go b/cmd/interactive.go index 80a915474..0a2f429b6 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -31,65 +31,40 @@ const ( ) func loadModel(cmd *cobra.Command, opts *runOptions) error { - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - p := progress.NewProgress(os.Stderr) defer p.StopAndClear() spinner := progress.NewSpinner("") p.Add("", spinner) - showReq := api.ShowRequest{Name: opts.Model} - showResp, err := client.Show(cmd.Context(), &showReq) + client, err := api.ClientFromEnvironment() if err != nil { return err } - opts.MultiModal = slices.Contains(showResp.Details.Families, "clip") - opts.ParentModel = showResp.Details.ParentModel - - if len(showResp.Messages) > 0 { - opts.Messages = append(opts.Messages, showResp.Messages...) 
- } chatReq := &api.ChatRequest{ - Model: opts.Model, - Messages: []api.Message{}, + Model: opts.Model, + KeepAlive: opts.KeepAlive, } - if opts.KeepAlive != nil { - chatReq.KeepAlive = opts.KeepAlive - } - - err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error { + return client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error { p.StopAndClear() - if len(opts.Messages) > 0 { - for _, msg := range opts.Messages { - switch msg.Role { - case "user": - fmt.Printf(">>> %s\n", msg.Content) - case "assistant": - state := &displayResponseState{} - displayResponse(msg.Content, opts.WordWrap, state) - fmt.Println() - fmt.Println() - } + for _, msg := range opts.Messages { + switch msg.Role { + case "user": + fmt.Printf(">>> %s\n", msg.Content) + case "assistant": + state := &displayResponseState{} + displayResponse(msg.Content, opts.WordWrap, state) + fmt.Println() + fmt.Println() } } return nil }) - if err != nil { - return err - } - - return nil } func generateInteractive(cmd *cobra.Command, opts runOptions) error { - opts.Messages = make([]api.Message, 0) - err := loadModel(cmd, &opts) if err != nil { return err From cb42e607c5cf4d439ad4d5a93ed13c7d6a09fc34 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Mon, 24 Jun 2024 21:47:52 -0700 Subject: [PATCH 24/54] llm: speed up gguf decoding by a lot (#5246) Previously, some costly things were causing the loading of GGUF files and their metadata and tensor information to be VERY slow: * Too many allocations when decoding strings * Hitting disk for each read of each key and value, resulting in a not-okay amount of syscalls/disk I/O. The show API is now down to 33ms from 800ms+ for llama3 on a macbook pro m3. This commit also prevents collecting large arrays of values when decoding GGUFs (if desired). When such keys are encountered, their values are null, and are encoded as such in JSON. Also, this fixes a broken test that was not encoding valid GGUF. 
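Much of the syscall reduction comes from wrapping the underlying io.ReadSeeker in a buffered reader whose Seek accounts for bytes already sitting in the buffer. A stripped-down sketch of that idea (mirroring the BufferedSeeker added under util/bufioutil in the diff below):

    // Sketch of the buffered-seeker approach: reads go through a bufio.Reader,
    // and relative seeks are adjusted by the number of buffered bytes so the
    // caller's position stays consistent with the underlying file.
    package main

    import (
        "bufio"
        "fmt"
        "io"
        "strings"
    )

    type bufferedSeeker struct {
        rs io.ReadSeeker
        br *bufio.Reader
    }

    func newBufferedSeeker(rs io.ReadSeeker, size int) *bufferedSeeker {
        return &bufferedSeeker{rs: rs, br: bufio.NewReaderSize(rs, size)}
    }

    func (b *bufferedSeeker) Read(p []byte) (int, error) { return b.br.Read(p) }

    func (b *bufferedSeeker) Seek(offset int64, whence int) (int64, error) {
        if whence == io.SeekCurrent {
            // The underlying file is ahead of the caller by the buffered amount.
            offset -= int64(b.br.Buffered())
        }
        n, err := b.rs.Seek(offset, whence)
        if err != nil {
            return 0, err
        }
        b.br.Reset(b.rs)
        return n, nil
    }

    func main() {
        bs := newBufferedSeeker(strings.NewReader("abcdefghijklmnopqrstuvwxyz"), 32)
        buf := make([]byte, 5)
        bs.Read(buf)                  // "abcde"
        bs.Seek(-5, io.SeekEnd)       // jump to the tail
        bs.Read(buf)
        fmt.Println(string(buf))      // vwxyz
    }
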
--- llm/ggla.go | 13 ++- llm/ggml.go | 25 ++++-- llm/ggml_test.go | 1 + llm/gguf.go | 130 +++++++++++++++++++-------- llm/memory_test.go | 19 ++-- llm/server.go | 11 ++- server/images.go | 2 +- server/model.go | 6 +- server/routes.go | 19 +++- server/sched.go | 2 +- server/sched_test.go | 6 +- util/bufioutil/buffer_seeker.go | 34 +++++++ util/bufioutil/buffer_seeker_test.go | 64 +++++++++++++ 13 files changed, 263 insertions(+), 69 deletions(-) create mode 100644 llm/ggml_test.go create mode 100644 util/bufioutil/buffer_seeker.go create mode 100644 util/bufioutil/buffer_seeker_test.go diff --git a/llm/ggla.go b/llm/ggla.go index a5d90b6cb..34c4f6ca3 100644 --- a/llm/ggla.go +++ b/llm/ggla.go @@ -53,7 +53,7 @@ func (llm *ggla) Tensors() Tensors { return llm.tensors } -func (llm *ggla) decode(rs io.ReadSeeker) error { +func (llm *ggla) decode(rs io.ReadSeeker) (retErr error) { var r uint32 if err := binary.Read(rs, binary.LittleEndian, &r); err != nil { return err @@ -69,9 +69,18 @@ func (llm *ggla) decode(rs io.ReadSeeker) error { for { var dims uint32 if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil { + if errors.Is(err, io.EOF) { + return nil + } return err } + defer func() { + if errors.Is(retErr, io.EOF) { + retErr = io.ErrUnexpectedEOF + } + }() + var namesize uint32 if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil { return err @@ -108,7 +117,7 @@ func (llm *ggla) decode(rs io.ReadSeeker) error { return err } - if _, err := rs.Seek((offset+31)&-32, io.SeekStart); err != nil { + if _, err := rs.Seek((offset+31)&-32-offset, io.SeekCurrent); err != nil { return err } diff --git a/llm/ggml.go b/llm/ggml.go index f02f0ff60..d0d0b6ddc 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -6,6 +6,8 @@ import ( "fmt" "io" "strings" + + "github.com/ollama/ollama/util/bufioutil" ) type GGML struct { @@ -278,7 +280,18 @@ func DetectGGMLType(b []byte) string { } } -func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) { +// DecodeGGML decodes a GGML model from the given reader. +// +// It collects array values for arrays with a size less than or equal to +// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If +// the maxArraySize is negative, all arrays are collected. 
+func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { + if maxArraySize == 0 { + maxArraySize = 1024 + } + + rs = bufioutil.NewBufferedSeeker(rs, 32<<10) + var magic uint32 if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { return nil, 0, err @@ -291,17 +304,15 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) { case FILE_MAGIC_GGLA: c = &containerGGLA{} case FILE_MAGIC_GGUF_LE: - c = &containerGGUF{ByteOrder: binary.LittleEndian} + c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize} case FILE_MAGIC_GGUF_BE: - c = &containerGGUF{ByteOrder: binary.BigEndian} + c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize} default: return nil, 0, errors.New("invalid file magic") } model, err := c.Decode(rs) - if errors.Is(err, io.EOF) { - // noop - } else if err != nil { + if err != nil { return nil, 0, err } @@ -321,7 +332,7 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui embedding := llm.KV().EmbeddingLength() heads := llm.KV().HeadCount() headsKV := llm.KV().HeadCountKV() - vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any))) + vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size) embeddingHeads := llm.KV().EmbeddingHeadCount() embeddingHeadsK := llm.KV().EmbeddingHeadCountK() diff --git a/llm/ggml_test.go b/llm/ggml_test.go new file mode 100644 index 000000000..006c3ded8 --- /dev/null +++ b/llm/ggml_test.go @@ -0,0 +1 @@ +package llm diff --git a/llm/gguf.go b/llm/gguf.go index 234efe574..4d343a1bd 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -3,11 +3,10 @@ package llm import ( "bytes" "encoding/binary" + "encoding/json" "fmt" "io" "strings" - - "log/slog" ) type containerGGUF struct { @@ -29,6 +28,12 @@ type containerGGUF struct { NumTensor uint64 NumKV uint64 } + + maxArraySize int +} + +func (c *containerGGUF) canCollectArray(size int) bool { + return c.maxArraySize < 0 || size <= c.maxArraySize } func (c *containerGGUF) Name() string { @@ -54,7 +59,6 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) { } model := newGGUF(c) - slog.Debug(fmt.Sprintf("model = %#v", model)) if err := model.Decode(rs); err != nil { return nil, err } @@ -85,6 +89,8 @@ type gguf struct { tensors []*Tensor parameters uint64 + + scratch [16 << 10]byte } func newGGUF(container *containerGGUF) *gguf { @@ -181,34 +187,34 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error { } // decode tensors - for i := 0; uint64(i) < llm.numTensor(); i++ { + for range llm.numTensor() { name, err := readGGUFString(llm, rs) if err != nil { - return err + return fmt.Errorf("failed to read tensor name: %w", err) } // dims is the number of dimensions in the tensor dims, err := readGGUF[uint32](llm, rs) if err != nil { - return err + return fmt.Errorf("failed to read tensor dimensions: %w", err) } shape := [4]uint64{1, 1, 1, 1} for i := 0; uint32(i) < dims; i++ { shape[i], err = readGGUF[uint64](llm, rs) if err != nil { - return err + return fmt.Errorf("failed to read tensor shape: %w", err) } } kind, err := readGGUF[uint32](llm, rs) if err != nil { - return err + return fmt.Errorf("failed to read tensor kind: %w", err) } offset, err := readGGUF[uint64](llm, rs) if err != nil { - return err + return fmt.Errorf("failed to read tensor offset: %w", err) } tensor := Tensor{ @@ -230,24 +236,19 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error { alignment = 32 } - offset, err := rs.Seek(0, io.SeekCurrent) - if err != nil { - return err - } - - padding := llm.padding(offset, 
int64(alignment)) - if _, err := rs.Seek(padding, io.SeekCurrent); err != nil { - return err - } - for _, tensor := range llm.tensors { - if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil { - return err + offset, err := rs.Seek(0, io.SeekCurrent) + if err != nil { + return fmt.Errorf("failed to get current offset: %w", err) } - padding := llm.padding(int64(tensor.Size()), int64(alignment)) + padding := llm.padding(offset, int64(alignment)) if _, err := rs.Seek(padding, io.SeekCurrent); err != nil { - return err + return fmt.Errorf("failed to seek to init padding: %w", err) + } + + if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil { + return fmt.Errorf("failed to seek to tensor: %w", err) } } @@ -285,22 +286,48 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) { return b.String(), nil } +func discardGGUFString(llm *gguf, r io.Reader) error { + buf := llm.scratch[:8] + _, err := io.ReadFull(r, buf) + if err != nil { + return err + } + + size := int(llm.ByteOrder.Uint64(buf)) + for size > 0 { + n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))]) + if err != nil { + return err + } + size -= n + } + return nil +} + func readGGUFString(llm *gguf, r io.Reader) (string, error) { if llm.Version == 1 { return readGGUFV1String(llm, r) } - var length uint64 - if err := binary.Read(r, llm.ByteOrder, &length); err != nil { + buf := llm.scratch[:8] + _, err := io.ReadFull(r, buf) + if err != nil { return "", err } - var b bytes.Buffer - if _, err := io.CopyN(&b, r, int64(length)); err != nil { + length := int(llm.ByteOrder.Uint64(buf)) + if length > len(llm.scratch) { + buf = make([]byte, length) + } else { + buf = llm.scratch[:length] + } + clear(buf) + + _, err = io.ReadFull(r, buf) + if err != nil { return "", err } - - return b.String(), nil + return string(buf), nil } func writeGGUFString(llm *gguf, w io.Writer, s string) error { @@ -316,7 +343,16 @@ func writeGGUFString(llm *gguf, w io.Writer, s string) error { return err } -func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) { +type array struct { + size int + values []any +} + +func (a *array) MarshalJSON() ([]byte, error) { + return json.Marshal(a.values) +} + +func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) { t, err := readGGUF[uint32](llm, r) if err != nil { return nil, err @@ -327,7 +363,12 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) { return nil, err } - for i := 0; uint32(i) < n; i++ { + a := &array{size: int(n)} + if llm.canCollectArray(int(n)) { + a.values = make([]any, 0, int(n)) + } + + for i := range n { var e any switch t { case ggufTypeUint8: @@ -361,13 +402,15 @@ func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) { return nil, err } - a = append(a, e) + if a.values != nil { + a.values[i] = e + } } - return + return a, nil } -func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) { +func readGGUFArray(llm *gguf, r io.Reader) (*array, error) { if llm.Version == 1 { return readGGUFV1Array(llm, r) } @@ -382,7 +425,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) { return nil, err } - for i := 0; uint64(i) < n; i++ { + a := &array{size: int(n)} + if llm.canCollectArray(int(n)) { + a.values = make([]any, int(n)) + } + + for i := range n { var e any switch t { case ggufTypeUint8: @@ -408,7 +456,11 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) { case ggufTypeBool: e, err = readGGUF[bool](llm, r) case ggufTypeString: - e, err = readGGUFString(llm, r) + if 
a.values != nil { + e, err = readGGUFString(llm, r) + } else { + err = discardGGUFString(llm, r) + } default: return nil, fmt.Errorf("invalid array type: %d", t) } @@ -416,10 +468,12 @@ func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) { return nil, err } - a = append(a, e) + if a.values != nil { + a.values[i] = e + } } - return + return a, nil } func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error { diff --git a/llm/memory_test.go b/llm/memory_test.go index 8eaa07715..f972f9275 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -22,13 +22,14 @@ func TestEstimateGPULayers(t *testing.T) { defer f.Close() gguf := NewGGUFV3(binary.LittleEndian) inputLayerCount := 5 + tensors := []Tensor{ - {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, - {Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, - {Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, - {Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, - {Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, - {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, + {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + {Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + {Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + {Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + {Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, } assert.Len(t, tensors, inputLayerCount+1) err = gguf.Encode(f, KV{ @@ -45,8 +46,10 @@ func TestEstimateGPULayers(t *testing.T) { }, tensors) require.NoError(t, err) - ggml, err := LoadModel(f.Name()) - require.NoError(t, err) + ggml, err := LoadModel(f.Name(), 0) + if err != nil { + t.Fatal(err) + } // Simple CPU scenario gpus := []gpu.GpuInfo{ diff --git a/llm/server.go b/llm/server.go index da83416ee..ad67138b5 100644 --- a/llm/server.go +++ b/llm/server.go @@ -60,7 +60,12 @@ type llmServer struct { sem *semaphore.Weighted } -func LoadModel(model string) (*GGML, error) { +// LoadModel will load a model from disk. The model must be in the GGML format. +// +// It collects array values for arrays with a size less than or equal to +// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If +// the maxArraySize is negative, all arrays are collected. 
+func LoadModel(model string, maxArraySize int) (*GGML, error) { if _, err := os.Stat(model); err != nil { return nil, err } @@ -71,7 +76,7 @@ func LoadModel(model string) (*GGML, error) { } defer f.Close() - ggml, _, err := DecodeGGML(f) + ggml, _, err := DecodeGGML(f, maxArraySize) return ggml, err } @@ -412,7 +417,7 @@ func projectorMemoryRequirements(filename string) uint64 { } defer file.Close() - ggml, _, err := DecodeGGML(file) + ggml, _, err := DecodeGGML(file, 0) if err != nil { return 0 } diff --git a/server/images.go b/server/images.go index 98794149e..e949fb18a 100644 --- a/server/images.go +++ b/server/images.go @@ -423,7 +423,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio return err } - ggml, _, err := llm.DecodeGGML(temp) + ggml, _, err := llm.DecodeGGML(temp, 0) if err != nil { return err } diff --git a/server/model.go b/server/model.go index b262ea385..055ffd63a 100644 --- a/server/model.go +++ b/server/model.go @@ -63,7 +63,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe } defer blob.Close() - ggml, _, err := llm.DecodeGGML(blob) + ggml, _, err := llm.DecodeGGML(blob, 0) if err != nil { return nil, err } @@ -176,7 +176,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a } defer bin.Close() - ggml, _, err := llm.DecodeGGML(bin) + ggml, _, err := llm.DecodeGGML(bin, 0) if err != nil { return nil, err } @@ -210,7 +210,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap var offset int64 for offset < stat.Size() { - ggml, n, err := llm.DecodeGGML(file) + ggml, n, err := llm.DecodeGGML(file, 0) if errors.Is(err, io.EOF) { break } else if err != nil { diff --git a/server/routes.go b/server/routes.go index 3d112e9f1..ff66663c0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -754,7 +754,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } func getKVData(digest string, verbose bool) (llm.KV, error) { - kvData, err := llm.LoadModel(digest) + maxArraySize := 0 + if verbose { + maxArraySize = -1 + } + kvData, err := llm.LoadModel(digest, maxArraySize) if err != nil { return nil, err } @@ -1101,11 +1105,20 @@ func Serve(ln net.Listener) error { schedCtx, schedDone := context.WithCancel(ctx) sched := InitScheduler(schedCtx) s := &Server{addr: ln.Addr(), sched: sched} - r := s.GenerateRoutes() + + http.Handle("/", s.GenerateRoutes()) slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version)) srvr := &http.Server{ - Handler: r, + // Use http.DefaultServeMux so we get net/http/pprof for + // free. + // + // TODO(bmizerany): Decide if we want to make this + // configurable so it is not exposed by default, or allow + // users to bind it to a different port. This was a quick + // and easy way to get pprof, but it may not be the best + // way. 
+ Handler: nil, } // listen for a ctrl+c and stop any loaded llm diff --git a/server/sched.go b/server/sched.go index 424395544..0084b533b 100644 --- a/server/sched.go +++ b/server/sched.go @@ -144,7 +144,7 @@ func (s *Scheduler) processPending(ctx context.Context) { } // Load model for fitting - ggml, err := llm.LoadModel(pending.model.ModelPath) + ggml, err := llm.LoadModel(pending.model.ModelPath, 0) if err != nil { pending.errCh <- err break diff --git a/server/sched_test.go b/server/sched_test.go index 953288347..4a1cf72a0 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -128,14 +128,14 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV "tokenizer.ggml.scores": []float32{0}, "tokenizer.ggml.token_type": []int32{0}, }, []llm.Tensor{ - {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, - {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, + {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, + {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, }) require.NoError(t, err) fname := f.Name() model := &Model{Name: modelName, ModelPath: fname} - scenario.ggml, err = llm.LoadModel(model.ModelPath) + scenario.ggml, err = llm.LoadModel(model.ModelPath, 0) require.NoError(t, err) scenario.req = &LlmRequest{ diff --git a/util/bufioutil/buffer_seeker.go b/util/bufioutil/buffer_seeker.go new file mode 100644 index 000000000..8775fdb83 --- /dev/null +++ b/util/bufioutil/buffer_seeker.go @@ -0,0 +1,34 @@ +package bufioutil + +import ( + "bufio" + "io" +) + +type BufferedSeeker struct { + rs io.ReadSeeker + br *bufio.Reader +} + +func NewBufferedSeeker(rs io.ReadSeeker, size int) *BufferedSeeker { + return &BufferedSeeker{ + rs: rs, + br: bufio.NewReaderSize(rs, size), + } +} + +func (b *BufferedSeeker) Read(p []byte) (int, error) { + return b.br.Read(p) +} + +func (b *BufferedSeeker) Seek(offset int64, whence int) (int64, error) { + if whence == io.SeekCurrent { + offset -= int64(b.br.Buffered()) + } + n, err := b.rs.Seek(offset, whence) + if err != nil { + return 0, err + } + b.br.Reset(b.rs) + return n, nil +} diff --git a/util/bufioutil/buffer_seeker_test.go b/util/bufioutil/buffer_seeker_test.go new file mode 100644 index 000000000..87145f6b6 --- /dev/null +++ b/util/bufioutil/buffer_seeker_test.go @@ -0,0 +1,64 @@ +package bufioutil + +import ( + "bytes" + "io" + "strings" + "testing" +) + +func TestBufferedSeeker(t *testing.T) { + const alphabet = "abcdefghijklmnopqrstuvwxyz" + + bs := NewBufferedSeeker(strings.NewReader(alphabet), 0) // minReadBufferSize = 16 + + checkRead := func(buf []byte, expected string) { + t.Helper() + _, err := bs.Read(buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, []byte(expected)) { + t.Fatalf("expected %s, got %s", expected, buf) + } + } + + // Read the first 5 bytes + buf := make([]byte, 5) + + checkRead(buf, "abcde") + + // Seek back to the beginning + _, err := bs.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + + // read 'a' + checkRead(buf[:1], "a") + + if bs.br.Buffered() == 0 { + t.Fatalf("totally unexpected sanity check failed") + } + + // Seek past 'b' + _, err = bs.Seek(1, io.SeekCurrent) + if err != nil { + t.Fatal(err) + } + checkRead(buf, "cdefg") + + // Seek back to the 
beginning + _, err = bs.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + checkRead(buf, "abcde") + + // Seek to the end + _, err = bs.Seek(-5, io.SeekEnd) + if err != nil { + t.Fatal(err) + } + checkRead(buf, "vwxyz") +} From 4d311eb731bb59512bcd17f1f33d60f3d9022837 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Wed, 26 Jun 2024 21:38:12 -0700 Subject: [PATCH 25/54] llm: architecture patch (#5316) --- llm/patches/07-gemma.diff | 305 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 305 insertions(+) create mode 100644 llm/patches/07-gemma.diff diff --git a/llm/patches/07-gemma.diff b/llm/patches/07-gemma.diff new file mode 100644 index 000000000..86eac3d17 --- /dev/null +++ b/llm/patches/07-gemma.diff @@ -0,0 +1,305 @@ +From 5cadb45f39d001ffbad95b690d6cf0abcb4a6d96 Mon Sep 17 00:00:00 2001 +From: Ollama maintainers +Date: Wed, 26 Jun 2024 16:18:09 -0700 +Subject: [PATCH] Architecture support + +--- + llama.cpp | 194 +++++++++++++++++++++++++++++++++++++++++++++++++++++- + 1 file changed, 193 insertions(+), 1 deletion(-) + +diff --git a/llama.cpp b/llama.cpp +index 61948751..3b4196f5 100644 +--- a/llama.cpp ++++ b/llama.cpp +@@ -217,6 +217,7 @@ enum llm_arch { + LLM_ARCH_INTERNLM2, + LLM_ARCH_MINICPM, + LLM_ARCH_GEMMA, ++ LLM_ARCH_GEMMA2, + LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, + LLM_ARCH_XVERSE, +@@ -255,6 +256,7 @@ static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_GEMMA, "gemma" }, ++ { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, +@@ -464,10 +466,12 @@ enum llm_tensor { + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_OUT_NORM, ++ LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_NORM, ++ LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, +@@ -960,6 +964,24 @@ static const std::map> LLM_TENSOR_NA + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, ++ { ++ LLM_ARCH_GEMMA2, ++ { ++ { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, ++ { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, ++ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, ++ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, ++ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, ++ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, ++ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, ++ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, ++ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, ++ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, ++ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, ++ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, ++ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, ++ }, ++ }, + { + LLM_ARCH_STARCODER2, + { +@@ -1941,6 +1963,8 @@ enum e_model { + MODEL_8x22B, + MODEL_16x12B, + MODEL_10B_128x3_66B, ++ MODEL_9B, ++ MODEL_27B, + }; + + static const size_t kiB = 1024; +@@ -2114,6 +2138,7 @@ struct llama_layer { + struct ggml_tensor * attn_out_norm_b; + struct ggml_tensor * attn_q_a_norm; + struct ggml_tensor * attn_kv_a_norm; ++ struct ggml_tensor * attn_post_norm; + + // attention + struct ggml_tensor * wq; +@@ -2136,6 +2161,7 @@ struct llama_layer { + // normalization + struct ggml_tensor * ffn_norm; + struct ggml_tensor * ffn_norm_b; ++ struct ggml_tensor * ffn_post_norm; + struct ggml_tensor * layer_out_norm; + struct ggml_tensor * layer_out_norm_b; + struct ggml_tensor * ffn_norm_exps; +@@ -4529,6 +4555,16 @@ static void 
llm_load_hparams( + } + } break; + case LLM_ARCH_GEMMA: ++ { ++ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ++ ++ switch (hparams.n_layer) { ++ case 18: model.type = e_model::MODEL_9B; break; ++ case 28: model.type = e_model::MODEL_27B; break; ++ default: model.type = e_model::MODEL_UNKNOWN; ++ } ++ } break; ++ case LLM_ARCH_GEMMA2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + +@@ -6305,6 +6341,40 @@ static bool llm_load_tensors( + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + } + } break; ++ case LLM_ARCH_GEMMA2: ++ { ++ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); ++ ++ // output ++ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); ++ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading ++ ++ const int64_t n_ff = hparams.n_ff; ++ const int64_t n_embd_head_k = hparams.n_embd_head_k; ++ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); ++ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); ++ ++ for (uint32_t i = 0; i < n_layer; ++i) { ++ ggml_context * ctx_layer = ctx_for_layer(i); ++ ggml_context * ctx_split = ctx_for_layer_split(i); ++ ++ auto & layer = model.layers[i]; ++ ++ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); ++ ++ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head}); ++ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); ++ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); ++ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd}); ++ layer.attn_post_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}); ++ ++ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); ++ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); ++ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); ++ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); ++ layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}); ++ } ++ } break; + case LLM_ARCH_STARCODER2: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); +@@ -10614,6 +10684,123 @@ struct llm_build_context { + return gf; + } + ++ struct ggml_cgraph * build_gemma2() { ++ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); ++ ++ const int64_t n_embd_head_k = hparams.n_embd_head_k; ++ ++ struct ggml_tensor * cur; ++ struct ggml_tensor * inpL; ++ ++ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); ++ ++ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); ++ cb(inpL, "inp_scaled", -1); ++ ++ // inp_pos - contains the positions ++ struct ggml_tensor * inp_pos = build_inp_pos(); ++ ++ // KQ_mask (mask for 1 head, it will be broadcasted to all heads) ++ struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); ++ ++ for (int il = 0; il < n_layer; ++il) { ++ // norm ++ cur = 
llm_build_norm(ctx0, inpL, hparams, ++ model.layers[il].attn_norm, NULL, ++ LLM_NORM_RMS, cb, il); ++ cb(cur, "attn_norm", il); ++ ++ // self-attention ++ { ++ // compute Q and K and RoPE them ++ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); ++ cb(Qcur, "Qcur", il); ++ ++ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); ++ cb(Kcur, "Kcur", il); ++ ++ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); ++ cb(Vcur, "Vcur", il); ++ ++ Qcur = ggml_rope_ext( ++ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, ++ n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, ++ ext_factor, attn_factor, beta_fast, beta_slow); ++ cb(Qcur, "Qcur", il); ++ ++ Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); ++ cb(Qcur, "Qcur_scaled", il); ++ ++ Kcur = ggml_rope_ext( ++ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr, ++ n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, ++ ext_factor, attn_factor, beta_fast, beta_slow); ++ cb(Kcur, "Kcur", il); ++ ++ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, ++ model.layers[il].wo, NULL, ++ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); ++ } ++ ++ if (il == n_layer - 1) { ++ // skip computing output for unused tokens ++ struct ggml_tensor * inp_out_ids = build_inp_out_ids(); ++ cur = ggml_get_rows(ctx0, cur, inp_out_ids); ++ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); ++ } ++ ++ cur = llm_build_norm(ctx0, cur, hparams, ++ model.layers[il].attn_post_norm, NULL, ++ LLM_NORM_RMS, cb, il); ++ cb(cur, "attn_post_norm", il); ++ ++ struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); ++ cb(sa_out, "sa_out", il); ++ ++ cur = llm_build_norm(ctx0, sa_out, hparams, ++ model.layers[il].ffn_norm, NULL, ++ LLM_NORM_RMS, cb, il); ++ cb(cur, "ffn_norm", il); ++ ++ // feed-forward network ++ { ++ cur = llm_build_ffn(ctx0, cur, ++ model.layers[il].ffn_up, NULL, ++ model.layers[il].ffn_gate, NULL, ++ model.layers[il].ffn_down, NULL, ++ NULL, ++ LLM_FFN_GELU, LLM_FFN_PAR, cb, il); ++ cb(cur, "ffn_out", il); ++ } ++ ++ cur = llm_build_norm(ctx0, cur, hparams, ++ model.layers[il].ffn_post_norm, NULL, ++ LLM_NORM_RMS, cb, -1); ++ cb(cur, "ffn_post_norm", -1); ++ ++ cur = ggml_add(ctx0, cur, sa_out); ++ cb(cur, "l_out", il); ++ ++ // input for next layer ++ inpL = cur; ++ } ++ ++ cur = inpL; ++ ++ cur = llm_build_norm(ctx0, cur, hparams, ++ model.output_norm, NULL, ++ LLM_NORM_RMS, cb, -1); ++ cb(cur, "result_norm", -1); ++ ++ // lm_head ++ cur = ggml_mul_mat(ctx0, model.output, cur); ++ cb(cur, "result_output", -1); ++ ++ ggml_build_forward_expand(gf, cur); ++ ++ return gf; ++ } ++ + struct ggml_cgraph * build_starcoder2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + +@@ -11847,6 +12034,10 @@ static struct ggml_cgraph * llama_build_graph( + { + result = llm.build_gemma(); + } break; ++ case LLM_ARCH_GEMMA2: ++ { ++ result = llm.build_gemma2(); ++ } break; + case LLM_ARCH_STARCODER2: + { + result = llm.build_starcoder2(); +@@ -16671,6 +16862,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { + case LLM_ARCH_PHI2: + case LLM_ARCH_PHI3: + case LLM_ARCH_GEMMA: ++ case LLM_ARCH_GEMMA2: + case LLM_ARCH_STARCODER2: + case LLM_ARCH_GPTNEOX: + return LLAMA_ROPE_TYPE_NEOX; +@@ -18551,7 +18743,7 @@ static int32_t llama_chat_apply_template_internal( + if (add_ass) { + ss << "assistant\n"; + } +- } else if (tmpl 
== "gemma" || tmpl.find("") != std::string::npos) { ++ } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl.find("") != std::string::npos) { + // google/gemma-7b-it + std::string system_prompt = ""; + for (auto message : chat) { +-- +2.45.2 + From 123a722a6f541e300bc8e34297ac378ebe23f527 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 26 Jun 2024 21:38:21 -0700 Subject: [PATCH 26/54] zip: prevent extracting files into parent dirs (#5314) --- cmd/cmd.go | 6 +-- server/model.go | 57 ++++++++++++++++++--------- server/model_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 22 deletions(-) create mode 100644 server/model_test.go diff --git a/cmd/cmd.go b/cmd/cmd.go index 89b551f40..909e8e4b2 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -162,9 +162,6 @@ func tempZipFiles(path string) (string, error) { } defer tempfile.Close() - zipfile := zip.NewWriter(tempfile) - defer zipfile.Close() - detectContentType := func(path string) (string, error) { f, err := os.Open(path) if err != nil { @@ -233,6 +230,9 @@ func tempZipFiles(path string) (string, error) { files = append(files, tks...) } + zipfile := zip.NewWriter(tempfile) + defer zipfile.Close() + for _, file := range files { f, err := os.Open(file) if err != nil { diff --git a/server/model.go b/server/model.go index 055ffd63a..d56e641ba 100644 --- a/server/model.go +++ b/server/model.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "path/filepath" + "strings" "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" @@ -77,62 +78,80 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return layers, nil } -func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) { +func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error { stat, err := file.Stat() if err != nil { - return nil, err + return err } r, err := zip.NewReader(file, stat.Size()) if err != nil { - return nil, err + return err } - tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "") - if err != nil { - return nil, err - } - defer os.RemoveAll(tempdir) - fn(api.ProgressResponse{Status: "unpacking model metadata"}) for _, f := range r.File { + n := filepath.Join(p, f.Name) + if !strings.HasPrefix(n, p) { + slog.Warn("skipped extracting file outside of context", "name", f.Name) + continue + } + + if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil { + return err + } + // TODO(mxyng): this should not write out all files to disk - outfile, err := os.Create(filepath.Join(tempdir, f.Name)) + outfile, err := os.Create(n) if err != nil { - return nil, err + return err } defer outfile.Close() infile, err := f.Open() if err != nil { - return nil, err + return err } defer infile.Close() if _, err = io.Copy(outfile, infile); err != nil { - return nil, err + return err } if err := outfile.Close(); err != nil { - return nil, err + return err } if err := infile.Close(); err != nil { - return nil, err + return err } } - mf, err := convert.GetModelFormat(tempdir) + return nil +} + +func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) { + tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "") + if err != nil { + return nil, err + } + defer os.RemoveAll(tempDir) + + if err := extractFromZipFile(tempDir, file, fn); err != nil { + return nil, err + } + + mf, err := convert.GetModelFormat(tempDir) if err != nil { return 
nil, err } - params, err := mf.GetParams(tempdir) + params, err := mf.GetParams(tempDir) if err != nil { return nil, err } - mArch, err := mf.GetModelArch("", tempdir, params) + mArch, err := mf.GetModelArch("", tempDir, params) if err != nil { return nil, err } @@ -150,7 +169,7 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a // TODO(mxyng): this should write directly into a layer // e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model") - temp, err := os.CreateTemp(tempdir, "fp16") + temp, err := os.CreateTemp(tempDir, "fp16") if err != nil { return nil, err } diff --git a/server/model_test.go b/server/model_test.go new file mode 100644 index 000000000..c3023eb2b --- /dev/null +++ b/server/model_test.go @@ -0,0 +1,92 @@ +package server + +import ( + "archive/zip" + "bytes" + "io" + "os" + "path/filepath" + "slices" + "testing" + + "github.com/ollama/ollama/api" +) + +func createZipFile(t *testing.T, name string) *os.File { + t.Helper() + + f, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatal(err) + } + + zf := zip.NewWriter(f) + defer zf.Close() + + zh, err := zf.CreateHeader(&zip.FileHeader{Name: name}) + if err != nil { + t.Fatal(err) + } + + if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil { + t.Fatal(err) + } + + return f +} + +func TestExtractFromZipFile(t *testing.T) { + cases := []struct { + name string + expect []string + }{ + { + name: "good", + expect: []string{"good"}, + }, + { + name: filepath.Join("..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"), + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + f := createZipFile(t, tt.name) + defer f.Close() + + tempDir := t.TempDir() + if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); err != nil { + t.Fatal(err) + } + + var matches []string + if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error { + if err != nil { + return err + } + + if !fi.IsDir() { + matches = append(matches, p) + } + + return nil + }); err != nil { + t.Fatal(err) + } + + var actual []string + for _, match := range matches { + rel, err := filepath.Rel(tempDir, match) + if err != nil { + t.Error(err) + } + + actual = append(actual, rel) + } + + if !slices.Equal(actual, tt.expect) { + t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches)) + } + }) + } +} From 2cc7d050124929ae4745633fddf053585a22f0a2 Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 27 Jun 2024 12:45:16 -0400 Subject: [PATCH 27/54] update readme for gemma 2 (#5333) * update readme for gemma 2 --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 978625731..72ed8fa5e 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,8 @@ Here are some example models that can be downloaded: | Llama 3 | 70B | 40GB | `ollama run llama3:70b` | | Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` | | Phi 3 Medium | 14B | 7.9GB | `ollama run phi3:medium` | -| Gemma | 2B | 1.4GB | `ollama run gemma:2b` | -| Gemma | 7B | 4.8GB | `ollama run gemma:7b` | +| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` | +| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` | | Mistral | 7B | 4.1GB | `ollama run mistral` | | Moondream 2 | 1.4B | 829MB | `ollama run moondream` | | Neural Chat | 7B | 4.1GB | `ollama run neural-chat` | From 4e986a823ca47eb16f563d15a6fe4cc393a00715 Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Thu, 27 Jun 2024 10:59:15 -0700 Subject: [PATCH 28/54] 
unquote, trimp space --- parser/parser.go | 9 ++++++++- parser/parser_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/parser/parser.go b/parser/parser.go index 686a1e695..fa60ebc0f 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -125,6 +125,7 @@ func ParseFile(r io.Reader) (*File, error) { // pass case stateValue: s, ok := unquote(b.String()) + if !ok || isSpace(r) { if _, err := b.WriteRune(r); err != nil { return nil, err @@ -158,7 +159,13 @@ func ParseFile(r io.Reader) (*File, error) { case stateComment, stateNil: // pass; nothing to flush case stateValue: - s, ok := unquote(b.String()) + var s string + var ok bool + if cmd.Name == "model" { + s, ok = unquote(strings.TrimSpace(b.String())) + } else { + s, ok = unquote(b.String()) + } if !ok { return nil, io.ErrUnexpectedEOF } diff --git a/parser/parser_test.go b/parser/parser_test.go index 7123e53bf..35556515d 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -48,6 +48,26 @@ func TestParseFileFrom(t *testing.T) { expected []Command err error }{ + { + "FROM \"FOO BAR \"", + []Command{{Name: "model", Args: "FOO BAR "}}, + nil, + }, + { + "FROM \"FOO BAR\"\nPARAMETER param1 value1", + []Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}}, + nil, + }, + { + "FROM FOOO BAR ", + []Command{{Name: "model", Args: "FOOO BAR"}}, + nil, + }, + { + "FROM /what/is/the path ", + []Command{{Name: "model", Args: "/what/is/the path"}}, + nil, + }, { "FROM foo", []Command{{Name: "model", Args: "foo"}}, @@ -86,6 +106,11 @@ func TestParseFileFrom(t *testing.T) { []Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}}, nil, }, + { + "PARAMETER what the \nFROM lemons make lemonade ", + []Command{{Name: "what", Args: "the "}, {Name: "model", Args: "lemons make lemonade"}}, + nil, + }, } for _, c := range cases { From 9bd00041fa1c82881299f34a5950f9edc2a7e66c Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Thu, 27 Jun 2024 11:18:38 -0700 Subject: [PATCH 29/54] trim all params --- parser/parser.go | 11 ++--------- parser/parser_test.go | 4 ++-- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/parser/parser.go b/parser/parser.go index fa60ebc0f..7f566da4e 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -124,8 +124,7 @@ func ParseFile(r io.Reader) (*File, error) { case stateComment, stateNil: // pass case stateValue: - s, ok := unquote(b.String()) - + s, ok := unquote(strings.TrimSpace(b.String())) if !ok || isSpace(r) { if _, err := b.WriteRune(r); err != nil { return nil, err @@ -159,13 +158,7 @@ func ParseFile(r io.Reader) (*File, error) { case stateComment, stateNil: // pass; nothing to flush case stateValue: - var s string - var ok bool - if cmd.Name == "model" { - s, ok = unquote(strings.TrimSpace(b.String())) - } else { - s, ok = unquote(b.String()) - } + s, ok := unquote(strings.TrimSpace(b.String())) if !ok { return nil, io.ErrUnexpectedEOF } diff --git a/parser/parser_test.go b/parser/parser_test.go index 35556515d..3dc592239 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -108,7 +108,7 @@ func TestParseFileFrom(t *testing.T) { }, { "PARAMETER what the \nFROM lemons make lemonade ", - []Command{{Name: "what", Args: "the "}, {Name: "model", Args: "lemons make lemonade"}}, + []Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}}, nil, }, } @@ -424,7 +424,7 @@ func TestParseFileParameters(t *testing.T) { "mirostat_eta 1.0": {"mirostat_eta", "1.0"}, "penalize_newline true": 
{"penalize_newline", "true"}, "stop ### User:": {"stop", "### User:"}, - "stop ### User: ": {"stop", "### User: "}, + "stop ### User: ": {"stop", "### User:"}, "stop \"### User:\"": {"stop", "### User:"}, "stop \"### User: \"": {"stop", "### User: "}, "stop \"\"\"### User:\"\"\"": {"stop", "### User:"}, From de2163dafd19b5ba2bed3d459354179662cc524d Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 27 Jun 2024 10:52:25 -0700 Subject: [PATCH 30/54] gemma2 graph --- llm/ggml.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/llm/ggml.go b/llm/ggml.go index d0d0b6ddc..cfead450d 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -366,9 +366,18 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui 4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16), ) } - case "gemma": - fullOffload = 4 * batch * (embedding + vocab) - partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128 + case "gemma", "gemma2": + fullOffload = max( + 4*batch*(embedding+vocab), + 4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads), + ) + + partialOffload = max( + 4*embedding*batch+embedding*vocab*105/128+4*vocab*batch, + 4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+ + 4*embeddingHeadsK*context*8+ + embedding*embeddingHeadsK*heads*9/16, + ) case "command-r": fullOffload = max( 4*batch*(embedding+vocab), From 6d4219083c56ec4b031f0fda67e9ef2c09ad9888 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Fri, 28 Jun 2024 09:58:14 -0700 Subject: [PATCH 31/54] Update docs (#5312) --- docs/openai.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/openai.md b/docs/openai.md index 59e7d6405..81b967eb7 100644 --- a/docs/openai.md +++ b/docs/openai.md @@ -104,7 +104,6 @@ curl http://localhost:11434/v1/chat/completions \ #### Notes -- `finish_reason` will always be `stop` - `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached ## Models From b910fa90101038d09ca9cbbea16701831fafaffb Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Fri, 28 Jun 2024 11:30:16 -0700 Subject: [PATCH 32/54] Ollama Show: Check for Projector Type (#5307) * Check exists projtype * Maintain Ordering --- cmd/cmd.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 909e8e4b2..debb39218 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -672,11 +672,17 @@ func ShowHandler(cmd *cobra.Command, args []string) error { projectorData := [][]string{ {"arch", "clip"}, {"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))}, - {"projector type", resp.ProjectorInfo["clip.projector_type"].(string)}, - {"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))}, - {"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))}, } + if projectorType, ok := resp.ProjectorInfo["clip.projector_type"]; ok { + projectorData = append(projectorData, []string{"projector type", projectorType.(string)}) + } + + projectorData = append(projectorData, + []string{"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))}, + []string{"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))}, + ) + mainTableData = append(mainTableData, 
[]string{"Projector"}, []string{renderSubTable(projectorData, false)}, From 5f034f5b63cab3a5eb61104118727b088cceea21 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Fri, 28 Jun 2024 13:15:52 -0700 Subject: [PATCH 33/54] Include Show Info in Interactive (#5342) --- cmd/cmd.go | 24 +++++++++++------------- cmd/interactive.go | 10 +--------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index debb39218..c898c7db6 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -624,13 +624,13 @@ func ShowHandler(cmd *cobra.Command, args []string) error { return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified") } - if flagsSet == 1 { - req := api.ShowRequest{Name: args[0]} - resp, err := client.Show(cmd.Context(), &req) - if err != nil { - return err - } + req := api.ShowRequest{Name: args[0]} + resp, err := client.Show(cmd.Context(), &req) + if err != nil { + return err + } + if flagsSet == 1 { switch showType { case "license": fmt.Println(resp.License) @@ -647,12 +647,12 @@ func ShowHandler(cmd *cobra.Command, args []string) error { return nil } - req := api.ShowRequest{Name: args[0]} - resp, err := client.Show(cmd.Context(), &req) - if err != nil { - return err - } + showInfo(resp) + return nil +} + +func showInfo(resp *api.ShowResponse) { arch := resp.ModelInfo["general.architecture"].(string) modelData := [][]string{ @@ -711,8 +711,6 @@ func ShowHandler(cmd *cobra.Command, args []string) error { } table.Render() - - return nil } func renderSubTable(data [][]string, file bool) string { diff --git a/cmd/interactive.go b/cmd/interactive.go index 0a2f429b6..9214f2db5 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -404,15 +404,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { switch args[1] { case "info": - fmt.Println("Model details:") - if len(resp.Details.Families) > 0 { - fmt.Printf("Family %s\n", strings.Join(resp.Details.Families, ", ")) - } else if resp.Details.Family != "" { - fmt.Printf("Family %s\n", resp.Details.Family) - } - fmt.Printf("Parameter Size %s\n", resp.Details.ParameterSize) - fmt.Printf("Quantization Level %s\n", resp.Details.QuantizationLevel) - fmt.Println("") + showInfo(resp) case "license": if resp.License == "" { fmt.Println("No license was specified for this model.") From aae56abb7cc96b8495a1c761a08b92cfd136d9d2 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 28 Jun 2024 13:15:57 -0700 Subject: [PATCH 34/54] Document concurrent behavior and settings --- docs/faq.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/faq.md b/docs/faq.md index b50a3138c..841f1d13d 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -257,3 +257,17 @@ If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` AP ## How do I manage the maximum number of requests the Ollama server can queue? If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`. + +## How does Ollama handle concurrent requests? + +Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. 
For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing. + +If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads. + +Parallel request processing for a given model results in increasing the context size by the number of parallel requests. For example, a 2K context with 4 parallel requests will result in an 8K context and additional memory allocation. + +The following server settings may be used to adjust how Ollama handles concurrent requests: + +- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference. +- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory. +- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512 From 717f7229eb4f9220d4070aae617923950643d327 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Fri, 28 Jun 2024 19:39:31 -0700 Subject: [PATCH 35/54] Do not shift context for sliding window models (#5368) * Do not shift context for sliding window models * truncate prompt > 2/3 tokens * only target gemma2 --- llm/ext_server/server.cpp | 46 +++++++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index 492126a4f..3bc012521 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -1650,26 +1650,41 @@ struct llama_server_context } slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); + char buf[256]; + llama_model_meta_val_str(model, "general.architecture", buf, 256); + bool gemma2 = strcmp(buf, "gemma2") == 0; + + int32_t truncate_at = slot.n_ctx; + + // truncate at 2/3 of the context length for gemma2 models + // as they do not support context shifts (from the sliding window implementation). 
+ // this way, prompts that almost fit the context length can still generate a full + // response without a sudden stop from hitting the context limit + if (gemma2) { + truncate_at = 2 * slot.n_ctx / 3; + } + // if input prompt is too big, truncate it, if group attention self-extend is disabled - if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) + if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at) { const int n_left = slot.n_ctx - slot.params.n_keep; - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + const int n_shift = n_left / 2; + const int n_erase = slot.n_prompt_tokens - slot.params.n_keep - n_shift; std::vector new_tokens( prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep); new_tokens.insert( new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + prompt_tokens.begin() + slot.params.n_keep + n_erase, prompt_tokens.end()); - LOG_VERBOSE("input truncated", { - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + LOG_INFO("input truncated", { + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_left", n_left}, + {"n_shift", n_shift}, + {"n_erase", n_erase}, }); slot.truncated = true; prompt_tokens = new_tokens; @@ -1678,6 +1693,19 @@ struct llama_server_context GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } + // Models with sliding window attention do not work with context shifts, so + // limit their prediction to the context length + if (gemma2) { + int32_t limit = slot.n_ctx - slot.n_prompt_tokens; + slot.n_predict = limit; + slot.params.n_predict = limit; + LOG_INFO("model does not support sliding window, limiting generation", { + {"n_ctx", slot.n_ctx}, + {"n_prompt_tokens", slot.n_prompt_tokens}, + {"n_predict", slot.n_predict} + }); + } + if (!slot.params.cache_prompt) { llama_sampling_reset(slot.ctx_sampling); From c1218199cfe82eda35f5e4a8031eee28f01ebf75 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 29 Jun 2024 16:22:49 -0700 Subject: [PATCH 36/54] Update api.md --- docs/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 107b5211f..c577bb1a5 100644 --- a/docs/api.md +++ b/docs/api.md @@ -26,7 +26,7 @@ All durations are returned in nanoseconds. ### Streaming responses -Certain endpoints stream responses as JSON objects and can optional return non-streamed responses. +Certain endpoints stream responses as JSON objects. Streaming can be disabled by providing `{"stream": false}` for these endpoints. 
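A small client-side sketch of the non-streamed form described above, using only net/http against the default local address; the model name is illustrative:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	// "stream": false asks the server for a single JSON object instead of a
	// stream of them.
	body := `{"model": "llama3", "prompt": "Why is the sky blue?", "stream": false}`

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```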
## Generate a completion From 27402cb7a28555a3efcaa5af054b1ce2d18e5442 Mon Sep 17 00:00:00 2001 From: Eduard Date: Mon, 1 Jul 2024 03:48:51 +0200 Subject: [PATCH 37/54] Update gpu.md (#5382) Runs fine on a NVIDIA GeForce GTX 1050 Ti --- docs/gpu.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/gpu.md b/docs/gpu.md index 55c41c9de..80f276c3b 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -18,7 +18,7 @@ Check your compute compatibility to see if your card is supported: | | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` | | 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` | | 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` | -| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050` | +| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` | | | Quadro | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` | | | Tesla | `P40` `P4` | | 6.0 | NVIDIA | `Tesla P100` `Quadro GP100` | From 1963c00201958da7165a40f9d2f22b28e11be718 Mon Sep 17 00:00:00 2001 From: RAPID ARCHITECT <126218667+rapidarchitect@users.noreply.github.com> Date: Sun, 30 Jun 2024 21:00:57 -0500 Subject: [PATCH 38/54] Update README.md (#5214) * Update README.md Added Mesop example to web & desktop * Update README.md --------- Co-authored-by: Jeffrey Morgan --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 72ed8fa5e..62f5cd65c 100644 --- a/README.md +++ b/README.md @@ -292,6 +292,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) +- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama) ### Terminal From 26e4e66faff20a94bb8fee9ec2bc3e17a07fb19e Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Mon, 1 Jul 2024 09:43:49 -0700 Subject: [PATCH 39/54] updated parsefile test --- parser/parser_test.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/parser/parser_test.go b/parser/parser_test.go index 3dc592239..171bd4206 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -22,7 +22,13 @@ ADAPTER adapter1 LICENSE MIT PARAMETER param1 value1 PARAMETER param2 value2 -TEMPLATE template1 +TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|> + +{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|> + +{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|> + +{{ .Response }}<|eot_id|>""" ` reader := strings.NewReader(input) @@ -36,7 +42,7 @@ TEMPLATE template1 {Name: "license", Args: "MIT"}, {Name: "param1", Args: "value1"}, {Name: "param2", Args: "value2"}, - {Name: "template", Args: "template1"}, + {Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"}, } assert.Equal(t, expectedCommands, modelfile.Commands) From cff3f44f4a4097de864d70d9a95f31c62e8ecdfa Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 1 Jul 2024 09:43:59 -0700 Subject: [PATCH 
40/54] Fix case for NumCtx --- server/sched.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/server/sched.go b/server/sched.go index 87da1db47..71b535ae2 100644 --- a/server/sched.go +++ b/server/sched.go @@ -23,7 +23,7 @@ type LlmRequest struct { ctx context.Context //nolint:containedctx model *Model opts api.Options - origNumCTX int // Track the initial ctx request + origNumCtx int // Track the initial ctx request sessionDuration time.Duration successCh chan *runnerRef errCh chan error @@ -118,8 +118,8 @@ func (s *Scheduler) processPending(ctx context.Context) { case pending := <-s.pendingReqCh: // Block other requests until we get this pending request running pending.schedAttempts++ - if pending.origNumCTX == 0 { - pending.origNumCTX = pending.opts.NumCtx + if pending.origNumCtx == 0 { + pending.origNumCtx = pending.opts.NumCtx } if pending.ctx.Err() != nil { @@ -135,7 +135,7 @@ func (s *Scheduler) processPending(ctx context.Context) { } // Keep NumCtx and numParallel in sync if numParallel > 1 { - pending.opts.NumCtx = pending.origNumCTX * numParallel + pending.opts.NumCtx = pending.origNumCtx * numParallel } for { @@ -197,7 +197,7 @@ func (s *Scheduler) processPending(ctx context.Context) { // simplifying assumption of defaultParallel when in CPU mode if numParallel <= 0 { numParallel = defaultParallel - pending.opts.NumCtx = pending.origNumCTX * numParallel + pending.opts.NumCtx = pending.origNumCtx * numParallel } if loadedCount == 0 { @@ -691,7 +691,7 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numP // First attempt to fit the model into a single GPU for _, p := range numParallelToTry { - req.opts.NumCtx = req.origNumCTX * p + req.opts.NumCtx = req.origNumCtx * p if !envconfig.SchedSpread { for _, g := range sgl { if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { @@ -709,7 +709,7 @@ func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numP // Now try all the GPUs for _, p := range numParallelToTry { - req.opts.NumCtx = req.origNumCTX * p + req.opts.NumCtx = req.origNumCtx * p if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM)) *numParallel = p From 173b5504381a77b042f3957226a23c0569406aca Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 1 Jul 2024 09:48:05 -0700 Subject: [PATCH 41/54] Remove default auto from help message This may confuse users thinking "auto" is an acceptable string - it must be numeric --- envconfig/config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 0f0f7f058..c02c4878e 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -85,13 +85,13 @@ func AsMap() map[string]EnvVar { "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, - "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU (default auto)"}, + "OLLAMA_MAX_LOADED_MODELS": 
{"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, - "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default auto)"}, + "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"}, From 3f0b309ad4c49c0d87839e50fe6a46163902aba0 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 10 Jun 2024 08:47:13 -0700 Subject: [PATCH 42/54] remove ManifestV2 --- server/images.go | 17 +++++------------ server/manifest.go | 20 +++++++++++--------- server/manifest_test.go | 2 +- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/server/images.go b/server/images.go index e949fb18a..447a63a69 100644 --- a/server/images.go +++ b/server/images.go @@ -135,13 +135,6 @@ type Message struct { Content string `json:"content"` } -type ManifestV2 struct { - SchemaVersion int `json:"schemaVersion"` - MediaType string `json:"mediaType"` - Config *Layer `json:"config"` - Layers []*Layer `json:"layers"` -} - type ConfigV2 struct { ModelFormat string `json:"model_format"` ModelFamily string `json:"model_family"` @@ -160,7 +153,7 @@ type RootFS struct { DiffIDs []string `json:"diff_ids"` } -func GetManifest(mp ModelPath) (*ManifestV2, string, error) { +func GetManifest(mp ModelPath) (*Manifest, string, error) { fp, err := mp.GetManifestPath() if err != nil { return nil, "", err @@ -170,7 +163,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) { return nil, "", err } - var manifest *ManifestV2 + var manifest *Manifest bts, err := os.ReadFile(fp) if err != nil { @@ -822,7 +815,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) - var manifest *ManifestV2 + var manifest *Manifest var err error var noprune string @@ -929,7 +922,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu return nil } -func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) { +func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) { requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) headers := make(http.Header) @@ -940,7 +933,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio } defer resp.Body.Close() - var m *ManifestV2 + var m *Manifest if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { return nil, err } diff --git a/server/manifest.go b/server/manifest.go index 61dd1ab4e..726bb48d8 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -14,7 +14,10 @@ import ( ) type Manifest struct { - ManifestV2 + SchemaVersion int 
`json:"schemaVersion"` + MediaType string `json:"mediaType"` + Config *Layer `json:"config"` + Layers []*Layer `json:"layers"` filepath string fi os.FileInfo @@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { p := filepath.Join(manifests, n.Filepath()) - var m ManifestV2 + var m Manifest f, err := os.Open(p) if err != nil { return nil, err @@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { return nil, err } - return &Manifest{ - ManifestV2: m, - filepath: p, - fi: fi, - digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), - }, nil + m.filepath = p + m.fi = fi + m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil)) + + return &m, nil } func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { @@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { } defer f.Close() - m := ManifestV2{ + m := Manifest{ SchemaVersion: 2, MediaType: "application/vnd.docker.distribution.manifest.v2+json", Config: config, diff --git a/server/manifest_test.go b/server/manifest_test.go index ceee31d88..ca6c3d2e9 100644 --- a/server/manifest_test.go +++ b/server/manifest_test.go @@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) { } defer f.Close() - if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil { + if err := json.NewEncoder(f).Encode(Manifest{}); err != nil { t.Fatal(err) } } From 58e3fff311f9e7abec20cdfe20fa43958e447aeb Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 10 Jun 2024 14:54:42 -0700 Subject: [PATCH 43/54] rename templates to template --- server/images.go | 26 ++- server/model.go | 4 +- server/prompt.go | 18 +- server/prompt_test.go | 15 +- server/routes.go | 26 ++- {templates => template}/alfred.gotmpl | 0 {templates => template}/alpaca.gotmpl | 0 {templates => template}/chatml.gotmpl | 0 {templates => template}/chatqa.gotmpl | 0 .../codellama-70b-instruct.gotmpl | 0 .../falcon-instruct.gotmpl | 0 {templates => template}/gemma-instruct.gotmpl | 0 .../granite-instruct.gotmpl | 0 {templates => template}/index.json | 0 {templates => template}/llama2-chat.gotmpl | 0 .../llama3-instruct.gotmpl | 0 {templates => template}/magicoder.gotmpl | 0 .../mistral-instruct.gotmpl | 0 {templates => template}/openchat.gotmpl | 0 {templates => template}/phi-3.gotmpl | 0 {templates => template}/solar-instruct.gotmpl | 0 .../starcoder2-instruct.gotmpl | 0 template/template.go | 158 ++++++++++++++++++ template/template_test.go | 89 ++++++++++ .../testdata/templates.jsonl | 0 {templates => template}/vicuna.gotmpl | 0 {templates => template}/zephyr.gotmpl | 0 templates/template.go | 70 -------- templates/template_test.go | 59 ------- 29 files changed, 301 insertions(+), 164 deletions(-) rename {templates => template}/alfred.gotmpl (100%) rename {templates => template}/alpaca.gotmpl (100%) rename {templates => template}/chatml.gotmpl (100%) rename {templates => template}/chatqa.gotmpl (100%) rename {templates => template}/codellama-70b-instruct.gotmpl (100%) rename {templates => template}/falcon-instruct.gotmpl (100%) rename {templates => template}/gemma-instruct.gotmpl (100%) rename {templates => template}/granite-instruct.gotmpl (100%) rename {templates => template}/index.json (100%) rename {templates => template}/llama2-chat.gotmpl (100%) rename {templates => template}/llama3-instruct.gotmpl (100%) rename {templates => template}/magicoder.gotmpl (100%) rename {templates => template}/mistral-instruct.gotmpl (100%) rename {templates => template}/openchat.gotmpl (100%) rename {templates 
=> template}/phi-3.gotmpl (100%) rename {templates => template}/solar-instruct.gotmpl (100%) rename {templates => template}/starcoder2-instruct.gotmpl (100%) create mode 100644 template/template.go create mode 100644 template/template_test.go rename {templates => template}/testdata/templates.jsonl (100%) rename {templates => template}/vicuna.gotmpl (100%) rename {templates => template}/zephyr.gotmpl (100%) delete mode 100644 templates/template.go delete mode 100644 templates/template_test.go diff --git a/server/images.go b/server/images.go index 447a63a69..65ed51c76 100644 --- a/server/images.go +++ b/server/images.go @@ -28,6 +28,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/template" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -48,12 +49,13 @@ type Model struct { ParentModel string AdapterPaths []string ProjectorPaths []string - Template string System string License []string Digest string Options map[string]interface{} Messages []Message + + Template *template.Template } func (m *Model) IsEmbedding() bool { @@ -82,10 +84,10 @@ func (m *Model) String() string { }) } - if m.Template != "" { + if m.Template != nil { modelfile.Commands = append(modelfile.Commands, parser.Command{ Name: "template", - Args: m.Template, + Args: m.Template.String(), }) } @@ -191,8 +193,7 @@ func GetModel(name string) (*Model, error) { Name: mp.GetFullTagname(), ShortName: mp.GetShortTagname(), Digest: digest, - Template: "{{ .Prompt }}", - License: []string{}, + Template: template.DefaultTemplate, } filename, err := GetBlobsPath(manifest.Config.Digest) @@ -228,13 +229,17 @@ func GetModel(name string) (*Model, error) { model.AdapterPaths = append(model.AdapterPaths, filename) case "application/vnd.ollama.image.projector": model.ProjectorPaths = append(model.ProjectorPaths, filename) - case "application/vnd.ollama.image.template": + case "application/vnd.ollama.image.prompt", + "application/vnd.ollama.image.template": bts, err := os.ReadFile(filename) if err != nil { return nil, err } - model.Template = string(bts) + model.Template, err = template.Parse(string(bts)) + if err != nil { + return nil, err + } case "application/vnd.ollama.image.system": bts, err := os.ReadFile(filename) if err != nil { @@ -242,13 +247,6 @@ func GetModel(name string) (*Model, error) { } model.System = string(bts) - case "application/vnd.ollama.image.prompt": - bts, err := os.ReadFile(filename) - if err != nil { - return nil, err - } - - model.Template = string(bts) case "application/vnd.ollama.image.params": params, err := os.Open(filename) if err != nil { diff --git a/server/model.go b/server/model.go index d56e641ba..6abb5b392 100644 --- a/server/model.go +++ b/server/model.go @@ -16,7 +16,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" "github.com/ollama/ollama/llm" - "github.com/ollama/ollama/templates" + "github.com/ollama/ollama/template" "github.com/ollama/ollama/types/model" ) @@ -258,7 +258,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) { for _, layer := range layers { if s := layer.GGML.KV().ChatTemplate(); s != "" { - if t, err := templates.NamedTemplate(s); err != nil { + if t, err := template.Named(s); err != nil { slog.Debug("template detection", "error", err) } else { tmpl, err := NewLayer(t.Reader(), 
"application/vnd.ollama.image.template") diff --git a/server/prompt.go b/server/prompt.go index 604e69717..bfc319a50 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -4,10 +4,11 @@ import ( "fmt" "log/slog" "strings" - "text/template" + "text/template/parse" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" ) // isResponseNode checks if the node contains .Response @@ -53,13 +54,8 @@ func formatTemplateForResponse(tmpl *template.Template, generate bool) { // Prompt renders a prompt from a template. If generate is set to true, // the response and parts of the template following it are not rendered -func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) { - parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl) - if err != nil { - return "", err - } - - formatTemplateForResponse(parsed, generate) +func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) { + formatTemplateForResponse(tmpl, generate) vars := map[string]any{ "System": system, @@ -68,14 +64,14 @@ func Prompt(tmpl, system, prompt, response string, generate bool) (string, error } var sb strings.Builder - if err := parsed.Execute(&sb, vars); err != nil { + if err := tmpl.Execute(&sb, vars); err != nil { return "", err } return sb.String(), nil } -func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) { +func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) { rendered, err := Prompt(tmpl, system, prompt, response, false) if err != nil { return 0, err @@ -91,7 +87,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc } // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size -func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { +func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { type prompt struct { System string Prompt string diff --git a/server/prompt_test.go b/server/prompt_test.go index a7e18a70f..7df58d0bd 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" ) func TestPrompt(t *testing.T) { @@ -61,7 +62,12 @@ func TestPrompt(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate) + tmpl, err := template.Parse(tc.template) + if err != nil { + t.Fatal(err) + } + + got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate) if err != nil { t.Errorf("error = %v", err) } @@ -192,7 +198,12 @@ func TestChatPrompt(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode) + tmpl, err := template.Parse(tc.template) + if err != nil { + t.Fatal(err) + } + + got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode) if err != nil { t.Errorf("error = %v", err) } diff --git a/server/routes.go b/server/routes.go index 76ead072f..d8a4a67e7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -31,6 +31,7 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" "github.com/ollama/ollama/parser" + 
"github.com/ollama/ollama/template" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + tmpl, err := template.Parse(req.Template) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + checkpointLoaded := time.Now() var prompt string @@ -169,7 +176,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { prompt = req.Prompt case req.Prompt != "": if req.Template == "" { - req.Template = model.Template + model.Template, err = template.Parse(req.Template) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } } if req.System == "" { @@ -187,7 +198,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { sb.WriteString(req.Prompt) - p, err := Prompt(req.Template, req.System, sb.String(), "", true) + p, err := Prompt(tmpl, req.System, sb.String(), "", true) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -242,7 +253,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) if !req.Raw { - p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false) + p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -680,7 +691,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } if req.Template != "" { - m.Template = req.Template + m.Template, err = template.Parse(req.Template) + if err != nil { + return nil, err + } } msgs := make([]api.Message, 0) @@ -701,7 +715,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { resp := &api.ShowResponse{ License: strings.Join(m.License, "\n"), System: m.System, - Template: m.Template, + Template: m.Template.String(), Details: modelDetails, Messages: msgs, ModifiedAt: manifest.fi.ModTime(), @@ -1246,7 +1260,7 @@ func (s *Server) ProcessHandler(c *gin.Context) { } // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model -func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) { +func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) { encode := func(s string) ([]int, error) { return runner.llama.Tokenize(ctx, s) } diff --git a/templates/alfred.gotmpl b/template/alfred.gotmpl similarity index 100% rename from templates/alfred.gotmpl rename to template/alfred.gotmpl diff --git a/templates/alpaca.gotmpl b/template/alpaca.gotmpl similarity index 100% rename from templates/alpaca.gotmpl rename to template/alpaca.gotmpl diff --git a/templates/chatml.gotmpl b/template/chatml.gotmpl similarity index 100% rename from templates/chatml.gotmpl rename to template/chatml.gotmpl diff --git a/templates/chatqa.gotmpl b/template/chatqa.gotmpl similarity index 100% rename from templates/chatqa.gotmpl rename to template/chatqa.gotmpl diff --git a/templates/codellama-70b-instruct.gotmpl b/template/codellama-70b-instruct.gotmpl similarity index 100% rename from templates/codellama-70b-instruct.gotmpl rename to template/codellama-70b-instruct.gotmpl diff --git a/templates/falcon-instruct.gotmpl b/template/falcon-instruct.gotmpl similarity index 100% rename from templates/falcon-instruct.gotmpl 
rename to template/falcon-instruct.gotmpl diff --git a/templates/gemma-instruct.gotmpl b/template/gemma-instruct.gotmpl similarity index 100% rename from templates/gemma-instruct.gotmpl rename to template/gemma-instruct.gotmpl diff --git a/templates/granite-instruct.gotmpl b/template/granite-instruct.gotmpl similarity index 100% rename from templates/granite-instruct.gotmpl rename to template/granite-instruct.gotmpl diff --git a/templates/index.json b/template/index.json similarity index 100% rename from templates/index.json rename to template/index.json diff --git a/templates/llama2-chat.gotmpl b/template/llama2-chat.gotmpl similarity index 100% rename from templates/llama2-chat.gotmpl rename to template/llama2-chat.gotmpl diff --git a/templates/llama3-instruct.gotmpl b/template/llama3-instruct.gotmpl similarity index 100% rename from templates/llama3-instruct.gotmpl rename to template/llama3-instruct.gotmpl diff --git a/templates/magicoder.gotmpl b/template/magicoder.gotmpl similarity index 100% rename from templates/magicoder.gotmpl rename to template/magicoder.gotmpl diff --git a/templates/mistral-instruct.gotmpl b/template/mistral-instruct.gotmpl similarity index 100% rename from templates/mistral-instruct.gotmpl rename to template/mistral-instruct.gotmpl diff --git a/templates/openchat.gotmpl b/template/openchat.gotmpl similarity index 100% rename from templates/openchat.gotmpl rename to template/openchat.gotmpl diff --git a/templates/phi-3.gotmpl b/template/phi-3.gotmpl similarity index 100% rename from templates/phi-3.gotmpl rename to template/phi-3.gotmpl diff --git a/templates/solar-instruct.gotmpl b/template/solar-instruct.gotmpl similarity index 100% rename from templates/solar-instruct.gotmpl rename to template/solar-instruct.gotmpl diff --git a/templates/starcoder2-instruct.gotmpl b/template/starcoder2-instruct.gotmpl similarity index 100% rename from templates/starcoder2-instruct.gotmpl rename to template/starcoder2-instruct.gotmpl diff --git a/template/template.go b/template/template.go new file mode 100644 index 000000000..d15f7156f --- /dev/null +++ b/template/template.go @@ -0,0 +1,158 @@ +package template + +import ( + "bytes" + "embed" + "encoding/json" + "errors" + "io" + "math" + "slices" + "strings" + "sync" + "text/template" + "text/template/parse" + + "github.com/agnivade/levenshtein" + "golang.org/x/exp/maps" +) + +//go:embed index.json +var indexBytes []byte + +//go:embed *.gotmpl +var templatesFS embed.FS + +var templatesOnce = sync.OnceValues(func() ([]*named, error) { + var templates []*named + if err := json.Unmarshal(indexBytes, &templates); err != nil { + return nil, err + } + + for _, t := range templates { + bts, err := templatesFS.ReadFile(t.Name + ".gotmpl") + if err != nil { + return nil, err + } + + // normalize line endings + t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n")) + } + + return templates, nil +}) + +type named struct { + Name string `json:"name"` + Template string `json:"template"` + Bytes []byte +} + +func (t named) Reader() io.Reader { + return bytes.NewReader(t.Bytes) +} + +func Named(s string) (*named, error) { + templates, err := templatesOnce() + if err != nil { + return nil, err + } + + var template *named + score := math.MaxInt + for _, t := range templates { + if s := levenshtein.ComputeDistance(s, t.Template); s < score { + score = s + template = t + } + } + + if score < 100 { + return template, nil + } + + return nil, errors.New("no matching template found") +} + +type Template struct { + *template.Template + raw 
string +} + +func (t *Template) String() string { + return t.raw +} + +var DefaultTemplate, _ = Parse("{{ .Prompt }}") + +func Parse(s string) (*Template, error) { + t, err := template.New("").Option("missingkey=zero").Parse(s) + if err != nil { + return nil, err + } + + return &Template{Template: t, raw: s}, nil +} + +func (t *Template) Vars() []string { + var vars []string + for _, n := range t.Tree.Root.Nodes { + vars = append(vars, parseNode(n)...) + } + + set := make(map[string]struct{}) + for _, n := range vars { + set[strings.ToLower(n)] = struct{}{} + } + + vars = maps.Keys(set) + slices.Sort(vars) + return vars +} + +func parseNode(n parse.Node) []string { + switch n := n.(type) { + case *parse.ActionNode: + return parseNode(n.Pipe) + case *parse.IfNode: + names := parseNode(n.Pipe) + names = append(names, parseNode(n.List)...) + if n.ElseList != nil { + names = append(names, parseNode(n.ElseList)...) + } + return names + case *parse.RangeNode: + names := parseNode(n.Pipe) + names = append(names, parseNode(n.List)...) + if n.ElseList != nil { + names = append(names, parseNode(n.ElseList)...) + } + return names + case *parse.WithNode: + names := parseNode(n.Pipe) + names = append(names, parseNode(n.List)...) + if n.ElseList != nil { + names = append(names, parseNode(n.ElseList)...) + } + return names + case *parse.PipeNode: + var names []string + for _, c := range n.Cmds { + for _, a := range c.Args { + names = append(names, parseNode(a)...) + } + } + return names + case *parse.ListNode: + var names []string + for _, n := range n.Nodes { + names = append(names, parseNode(n)...) + } + + return names + case *parse.FieldNode: + return n.Ident + } + + return nil +} diff --git a/template/template_test.go b/template/template_test.go new file mode 100644 index 000000000..e5405bdb4 --- /dev/null +++ b/template/template_test.go @@ -0,0 +1,89 @@ +package template + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "os" + "path/filepath" + "slices" + "testing" + "text/template" + + "github.com/ollama/ollama/llm" +) + +func TestNamed(t *testing.T) { + f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + var ss map[string]string + if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil { + t.Fatal(err) + } + + for k, v := range ss { + t.Run(k, func(t *testing.T) { + kv := llm.KV{"tokenizer.chat_template": v} + s := kv.ChatTemplate() + r, err := Named(s) + if err != nil { + t.Fatal(err) + } + + if r.Name != k { + t.Errorf("expected %q, got %q", k, r.Name) + } + + var b bytes.Buffer + if _, err := io.Copy(&b, r.Reader()); err != nil { + t.Fatal(err) + } + + tmpl, err := template.New(s).Parse(b.String()) + if err != nil { + t.Fatal(err) + } + + if tmpl.Tree.Root.String() == "" { + t.Errorf("empty %s template", k) + } + }) + } + } +} + +func TestParse(t *testing.T) { + cases := []struct { + template string + capabilities []string + }{ + {"{{ .Prompt }}", []string{"prompt"}}, + {"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}}, + {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}}, + {"{{ with .Tools }}{{ . 
}}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}}, + {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, + {"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}}, + {"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}}, + } + + for _, tt := range cases { + t.Run("", func(t *testing.T) { + tmpl, err := Parse(tt.template) + if err != nil { + t.Fatal(err) + } + + vars := tmpl.Vars() + if !slices.Equal(tt.capabilities, vars) { + t.Errorf("expected %v, got %v", tt.capabilities, vars) + } + }) + } +} diff --git a/templates/testdata/templates.jsonl b/template/testdata/templates.jsonl similarity index 100% rename from templates/testdata/templates.jsonl rename to template/testdata/templates.jsonl diff --git a/templates/vicuna.gotmpl b/template/vicuna.gotmpl similarity index 100% rename from templates/vicuna.gotmpl rename to template/vicuna.gotmpl diff --git a/templates/zephyr.gotmpl b/template/zephyr.gotmpl similarity index 100% rename from templates/zephyr.gotmpl rename to template/zephyr.gotmpl diff --git a/templates/template.go b/templates/template.go deleted file mode 100644 index 72bd69e9d..000000000 --- a/templates/template.go +++ /dev/null @@ -1,70 +0,0 @@ -package templates - -import ( - "bytes" - "embed" - "encoding/json" - "errors" - "io" - "math" - "sync" - - "github.com/agnivade/levenshtein" -) - -//go:embed index.json -var indexBytes []byte - -//go:embed *.gotmpl -var templatesFS embed.FS - -var templatesOnce = sync.OnceValues(func() ([]*Template, error) { - var templates []*Template - if err := json.Unmarshal(indexBytes, &templates); err != nil { - return nil, err - } - - for _, t := range templates { - bts, err := templatesFS.ReadFile(t.Name + ".gotmpl") - if err != nil { - return nil, err - } - - // normalize line endings - t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n")) - } - - return templates, nil -}) - -type Template struct { - Name string `json:"name"` - Template string `json:"template"` - Bytes []byte -} - -func (t Template) Reader() io.Reader { - return bytes.NewReader(t.Bytes) -} - -func NamedTemplate(s string) (*Template, error) { - templates, err := templatesOnce() - if err != nil { - return nil, err - } - - var template *Template - score := math.MaxInt - for _, t := range templates { - if s := levenshtein.ComputeDistance(s, t.Template); s < score { - score = s - template = t - } - } - - if score < 100 { - return template, nil - } - - return nil, errors.New("no matching template found") -} diff --git a/templates/template_test.go b/templates/template_test.go deleted file mode 100644 index 61bc78374..000000000 --- a/templates/template_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package templates - -import ( - "bufio" - "bytes" - "encoding/json" - "io" - "os" - "path/filepath" - "testing" - "text/template" - - "github.com/ollama/ollama/llm" -) - -func TestKVChatTemplate(t *testing.T) { - f, err := os.Open(filepath.Join("testdata", "templates.jsonl")) - if err != nil { - t.Fatal(err) - } - defer f.Close() - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - var ss map[string]string - if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil { - t.Fatal(err) - } - - for k, v := range ss { - t.Run(k, func(t *testing.T) { - kv := llm.KV{"tokenizer.chat_template": v} - s := 
kv.ChatTemplate() - r, err := NamedTemplate(s) - if err != nil { - t.Fatal(err) - } - - if r.Name != k { - t.Errorf("expected %q, got %q", k, r.Name) - } - - var b bytes.Buffer - if _, err := io.Copy(&b, r.Reader()); err != nil { - t.Fatal(err) - } - - tmpl, err := template.New(s).Parse(b.String()) - if err != nil { - t.Fatal(err) - } - - if tmpl.Tree.Root.String() == "" { - t.Errorf("empty %s template", k) - } - }) - } - } -} From a30915bde166b2f392a0ff72c61c9ac53189a962 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 11 Jun 2024 14:03:42 -0700 Subject: [PATCH 44/54] add capabilities --- server/images.go | 20 ++++++++++++++++++-- server/routes.go | 8 ++++---- template/template_test.go | 8 ++++---- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/server/images.go b/server/images.go index 65ed51c76..5cd0a7a53 100644 --- a/server/images.go +++ b/server/images.go @@ -34,6 +34,10 @@ import ( "github.com/ollama/ollama/version" ) +type Capability string + +const CapabilityCompletion = Capability("completion") + type registryOptions struct { Insecure bool Username string @@ -58,8 +62,20 @@ type Model struct { Template *template.Template } -func (m *Model) IsEmbedding() bool { - return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") +func (m *Model) Has(caps ...Capability) bool { + for _, cap := range caps { + switch cap { + case CapabilityCompletion: + if slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") { + return false + } + default: + slog.Error("unknown capability", "capability", cap) + return false + } + } + + return true } func (m *Model) String() string { diff --git a/server/routes.go b/server/routes.go index d8a4a67e7..8ca6dcc89 100644 --- a/server/routes.go +++ b/server/routes.go @@ -122,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - if model.IsEmbedding() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"}) + if !model.Has(CapabilityCompletion) { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)}) return } @@ -1308,8 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - if model.IsEmbedding() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"}) + if !model.Has(CapabilityCompletion) { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)}) return } diff --git a/template/template_test.go b/template/template_test.go index e5405bdb4..eda4634f4 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -61,8 +61,8 @@ func TestNamed(t *testing.T) { func TestParse(t *testing.T) { cases := []struct { - template string - capabilities []string + template string + vars []string }{ {"{{ .Prompt }}", []string{"prompt"}}, {"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}}, @@ -81,8 +81,8 @@ func TestParse(t *testing.T) { } vars := tmpl.Vars() - if !slices.Equal(tt.capabilities, vars) { - t.Errorf("expected %v, got %v", tt.capabilities, vars) + if !slices.Equal(tt.vars, vars) { + t.Errorf("expected %v, got %v", tt.vars, vars) } }) } From da8e2a04479f96ad9c57eaf25ed26b79b239b05c Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 14 Jun 2024 14:57:49 -0700 Subject: [PATCH 45/54] use kvs to detect embedding models --- server/images.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 
deletion(-) diff --git a/server/images.go b/server/images.go index 5cd0a7a53..a62991f16 100644 --- a/server/images.go +++ b/server/images.go @@ -66,7 +66,21 @@ func (m *Model) Has(caps ...Capability) bool { for _, cap := range caps { switch cap { case CapabilityCompletion: - if slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") { + f, err := os.Open(m.ModelPath) + if err != nil { + slog.Error("couldn't open model file", "error", err) + continue + } + defer f.Close() + + // TODO(mxyng): decode the GGML into model to avoid doing this multiple times + ggml, _, err := llm.DecodeGGML(f, 0) + if err != nil { + slog.Error("couldn't decode ggml", "error", err) + continue + } + + if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { return false } default: From 7e571f95f0306f90e4f754e34df96ebc36f93626 Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Mon, 1 Jul 2024 11:07:48 -0700 Subject: [PATCH 46/54] trimspace test case --- parser/parser_test.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/parser/parser_test.go b/parser/parser_test.go index 171bd4206..2b5c4c888 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -48,6 +48,39 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|> assert.Equal(t, expectedCommands, modelfile.Commands) } +func TestParseFileTrimSpace(t *testing.T) { + input := ` +FROM " model 1" +ADAPTER adapter3 +LICENSE "MIT " +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|> + +{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|> + +{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|> + +{{ .Response }}<|eot_id|> """ +` + + reader := strings.NewReader(input) + + modelfile, err := ParseFile(reader) + require.NoError(t, err) + + expectedCommands := []Command{ + {Name: "model", Args: " model 1"}, + {Name: "adapter", Args: "adapter3"}, + {Name: "license", Args: "MIT "}, + {Name: "param1", Args: "value1"}, + {Name: "param2", Args: "value2"}, + {Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "}, + } + + assert.Equal(t, expectedCommands, modelfile.Commands) +} + func TestParseFileFrom(t *testing.T) { var cases = []struct { input string From 88bcd79bb9a4b2baa739efe2ccabcbcf3c89bdb5 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Sun, 30 Jun 2024 11:10:40 -0700 Subject: [PATCH 47/54] err on insecure path --- server/model.go | 8 +++----- server/model_test.go | 24 ++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/server/model.go b/server/model.go index d56e641ba..7d5957a18 100644 --- a/server/model.go +++ b/server/model.go @@ -11,7 +11,6 @@ import ( "net/http" "os" "path/filepath" - "strings" "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" @@ -91,12 +90,11 @@ func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) fn(api.ProgressResponse{Status: "unpacking model metadata"}) for _, f := range r.File { - n := filepath.Join(p, f.Name) - if !strings.HasPrefix(n, p) { - slog.Warn("skipped extracting file outside of context", "name", f.Name) - continue + if !filepath.IsLocal(f.Name) { + return 
fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name) } + n := filepath.Join(p, f.Name) if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil { return err } diff --git a/server/model_test.go b/server/model_test.go index c3023eb2b..a383b7e72 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -3,10 +3,12 @@ package server import ( "archive/zip" "bytes" + "errors" "io" "os" "path/filepath" "slices" + "strings" "testing" "github.com/ollama/ollama/api" @@ -39,13 +41,31 @@ func TestExtractFromZipFile(t *testing.T) { cases := []struct { name string expect []string + err error }{ { name: "good", expect: []string{"good"}, }, { - name: filepath.Join("..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"), + name: strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)), + expect: []string{filepath.Join("to", "good")}, + }, + { + name: strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)), + expect: []string{"good"}, + }, + { + name: strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)), + expect: []string{"good"}, + }, + { + name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)), + err: zip.ErrInsecurePath, + }, + { + name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)), + err: zip.ErrInsecurePath, }, } @@ -55,7 +75,7 @@ func TestExtractFromZipFile(t *testing.T) { defer f.Close() tempDir := t.TempDir() - if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); err != nil { + if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) { t.Fatal(err) } From 33a65e3ba3ad5666d6ba8430efbccfa6d642d1de Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Mon, 1 Jul 2024 16:04:13 -0700 Subject: [PATCH 48/54] error --- llm/server.go | 3 +++ llm/status.go | 1 + 2 files changed, 4 insertions(+) diff --git a/llm/server.go b/llm/server.go index 61346069e..8b63cfbd5 100644 --- a/llm/server.go +++ b/llm/server.go @@ -560,6 +560,9 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { if s.status != nil && s.status.LastErrMsg != "" { msg = s.status.LastErrMsg } + if strings.Contains(msg, "unknown model") { + return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade") + } return fmt.Errorf("llama runner process has terminated: %v %s", err, msg) default: } diff --git a/llm/status.go b/llm/status.go index 8a49bd55a..0f56b7f99 100644 --- a/llm/status.go +++ b/llm/status.go @@ -25,6 +25,7 @@ var errorPrefixes = []string{ "CUDA error", "cudaMalloc failed", "\"ERR\"", + "architecture", } func (w *StatusWriter) Write(b []byte) (int, error) { From 4f67b39d262b1997aa96c47585f1d8e8443d0f90 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 2 Jul 2024 09:22:17 -0700 Subject: [PATCH 49/54] Centos 7 EOL broke mirrors As of July 1st 2024: Could not resolve host: mirrorlist.centos.org This is expected due to EOL dates. 
--- scripts/rh_linux_deps.sh | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/scripts/rh_linux_deps.sh b/scripts/rh_linux_deps.sh index ed60e4304..81648d68e 100644 --- a/scripts/rh_linux_deps.sh +++ b/scripts/rh_linux_deps.sh @@ -6,10 +6,21 @@ set -ex MACHINE=$(uname -m) if grep -i "centos" /etc/system-release >/dev/null; then + # As of 7/1/2024 mirrorlist.centos.org has been taken offline, so adjust accordingly + sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo + sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo + sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo + # Centos 7 derivatives have too old of a git version to run our generate script # uninstall and ignore failures yum remove -y git yum -y install epel-release centos-release-scl + + # The release packages reinstate the mirrors, undo that again + sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo + sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo + sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo + yum -y install dnf if [ "${MACHINE}" = "x86_64" ]; then yum -y install https://repo.ius.io/ius-release-el7.rpm From 020bd60ab2f156661b072515cd2c27d59b956535 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 2 Jul 2024 10:23:05 -0700 Subject: [PATCH 50/54] Switch amd container image base to rocky 8 The centos 7 arm mirrors have disappeared due to the EOL 2 days ago, and the vault sed workaround which works for x86 doesn't work for arm. --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 98a3ddfd2..b2c5c4a2f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -70,12 +70,12 @@ RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64 RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh -FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64 +FROM --platform=linux/arm64 rockylinux:8 AS cpu-builder-arm64 ARG CMAKE_VERSION ARG GOLANG_VERSION COPY ./scripts/rh_linux_deps.sh / RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh -ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH +ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH COPY --from=llm-code / /go/src/github.com/ollama/ollama/ ARG OLLAMA_CUSTOM_CPU_DEFS ARG CGO_CFLAGS From 996bb1b85e0c1b3ae64246a50ea412dc2a2e30d8 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Tue, 2 Jul 2024 11:50:56 -0700 Subject: [PATCH 51/54] OpenAI: /v1/models and /v1/models/{model} compatibility (#5007) * OpenAI v1 models * Refactor Writers * Add Test Co-Authored-By: Attila Kerekes * Credit Co-Author Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com> * Empty List Testing * Use Namespace for Ownedby * Update Test * Add back envconfig * v1/models docs * Use ModelName Parser * Test Names * Remove Docs * Clean Up * Test name Co-authored-by: Jeffrey Morgan * Add Middleware for Chat and List * Testing Cleanup * Test with Fatal * Add functionality to chat test * OpenAI: /v1/models/{model} compatibility (#5028) * Retrieve Model * OpenAI Delete Model * Retrieve Middleware * Remove Delete from Branch * Update Test * Middleware Test File * Function name * Cleanup * Test Update * Test Update --------- Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com> Co-authored-by: Jeffrey Morgan --- api/types.go | 7 ++ 
docs/openai.md | 1 + openai/openai.go | 163 ++++++++++++++++++++++++++++++++++++---- openai/openai_test.go | 170 ++++++++++++++++++++++++++++++++++++++++++ server/routes.go | 4 +- server/routes_test.go | 56 ++++++++++++++ 6 files changed, 387 insertions(+), 14 deletions(-) create mode 100644 openai/openai_test.go diff --git a/api/types.go b/api/types.go index 95ed5d37e..428281ba6 100644 --- a/api/types.go +++ b/api/types.go @@ -345,6 +345,13 @@ type ProcessModelResponse struct { SizeVRAM int64 `json:"size_vram"` } +type RetrieveModelResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` +} + type TokenResponse struct { Token string `json:"token"` } diff --git a/docs/openai.md b/docs/openai.md index 81b967eb7..9dda05c3a 100644 --- a/docs/openai.md +++ b/docs/openai.md @@ -65,6 +65,7 @@ curl http://localhost:11434/v1/chat/completions \ } ] }' + ``` ## Endpoints diff --git a/openai/openai.go b/openai/openai.go index 706d31aa2..01da44409 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -12,6 +12,7 @@ import ( "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/types/model" ) type Error struct { @@ -85,6 +86,18 @@ type ChatCompletionChunk struct { Choices []ChunkChoice `json:"choices"` } +type Model struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` +} + +type ListCompletion struct { + Object string `json:"object"` + Data []Model `json:"data"` +} + func NewError(code int, message string) ErrorResponse { var etype string switch code { @@ -145,7 +158,33 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { } } -func fromRequest(r ChatCompletionRequest) api.ChatRequest { +func toListCompletion(r api.ListResponse) ListCompletion { + var data []Model + for _, m := range r.Models { + data = append(data, Model{ + Id: m.Name, + Object: "model", + Created: m.ModifiedAt.Unix(), + OwnedBy: model.ParseName(m.Name).Namespace, + }) + } + + return ListCompletion{ + Object: "list", + Data: data, + } +} + +func toModel(r api.ShowResponse, m string) Model { + return Model{ + Id: m, + Object: "model", + Created: r.ModifiedAt.Unix(), + OwnedBy: model.ParseName(m).Namespace, + } +} + +func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { var messages []api.Message for _, msg := range r.Messages { messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content}) @@ -208,13 +247,26 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest { } } -type writer struct { - stream bool - id string +type BaseWriter struct { gin.ResponseWriter } -func (w *writer) writeError(code int, data []byte) (int, error) { +type ChatWriter struct { + stream bool + id string + BaseWriter +} + +type ListWriter struct { + BaseWriter +} + +type RetrieveWriter struct { + BaseWriter + model string +} + +func (w *BaseWriter) writeError(code int, data []byte) (int, error) { var serr api.StatusError err := json.Unmarshal(data, &serr) if err != nil { @@ -230,7 +282,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) { return len(data), nil } -func (w *writer) writeResponse(data []byte) (int, error) { +func (w *ChatWriter) writeResponse(data []byte) (int, error) { var chatResponse api.ChatResponse err := json.Unmarshal(data, &chatResponse) if err != nil { @@ -270,7 +322,7 @@ func (w *writer) writeResponse(data []byte) (int, error) { return len(data), nil } -func (w *writer) Write(data 
[]byte) (int, error) { +func (w *ChatWriter) Write(data []byte) (int, error) { code := w.ResponseWriter.Status() if code != http.StatusOK { return w.writeError(code, data) @@ -279,7 +331,92 @@ func (w *writer) Write(data []byte) (int, error) { return w.writeResponse(data) } -func Middleware() gin.HandlerFunc { +func (w *ListWriter) writeResponse(data []byte) (int, error) { + var listResponse api.ListResponse + err := json.Unmarshal(data, &listResponse) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *ListWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(code, data) + } + + return w.writeResponse(data) +} + +func (w *RetrieveWriter) writeResponse(data []byte) (int, error) { + var showResponse api.ShowResponse + err := json.Unmarshal(data, &showResponse) + if err != nil { + return 0, err + } + + // retrieve completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *RetrieveWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(code, data) + } + + return w.writeResponse(data) +} + +func ListMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + w := &ListWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + } + + c.Writer = w + + c.Next() + } +} + +func RetrieveMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + // response writer + w := &RetrieveWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + model: c.Param("model"), + } + + c.Writer = w + + c.Next() + } +} + +func ChatMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var req ChatCompletionRequest err := c.ShouldBindJSON(&req) @@ -294,17 +431,17 @@ func Middleware() gin.HandlerFunc { } var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil { + if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) return } c.Request.Body = io.NopCloser(&b) - w := &writer{ - ResponseWriter: c.Writer, - stream: req.Stream, - id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), + w := &ChatWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), } c.Writer = w diff --git a/openai/openai_test.go b/openai/openai_test.go new file mode 100644 index 000000000..1f335b965 --- /dev/null +++ b/openai/openai_test.go @@ -0,0 +1,170 @@ +package openai + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/ollama/ollama/api" + "github.com/stretchr/testify/assert" +) + +func TestMiddleware(t *testing.T) { + type testCase struct { + Name string + Method string + Path string + TestPath 
string + Handler func() gin.HandlerFunc + Endpoint func(c *gin.Context) + Setup func(t *testing.T, req *http.Request) + Expected func(t *testing.T, resp *httptest.ResponseRecorder) + } + + testCases := []testCase{ + { + Name: "chat handler", + Method: http.MethodPost, + Path: "/api/chat", + TestPath: "/api/chat", + Handler: ChatMiddleware, + Endpoint: func(c *gin.Context) { + var chatReq api.ChatRequest + if err := c.ShouldBindJSON(&chatReq); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) + return + } + + userMessage := chatReq.Messages[0].Content + var assistantMessage string + + switch userMessage { + case "Hello": + assistantMessage = "Hello!" + default: + assistantMessage = "I'm not sure how to respond to that." + } + + c.JSON(http.StatusOK, api.ChatResponse{ + Message: api.Message{ + Role: "assistant", + Content: assistantMessage, + }, + }) + }, + Setup: func(t *testing.T, req *http.Request) { + body := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{{Role: "user", Content: "Hello"}}, + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + var chatResp ChatCompletion + if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { + t.Fatal(err) + } + + if chatResp.Object != "chat.completion" { + t.Fatalf("expected chat.completion, got %s", chatResp.Object) + } + + if chatResp.Choices[0].Message.Content != "Hello!" { + t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content) + } + }, + }, + { + Name: "list handler", + Method: http.MethodGet, + Path: "/api/tags", + TestPath: "/api/tags", + Handler: ListMiddleware, + Endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ListResponse{ + Models: []api.ListModelResponse{ + { + Name: "Test Model", + }, + }, + }) + }, + Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + var listResp ListCompletion + if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { + t.Fatal(err) + } + + if listResp.Object != "list" { + t.Fatalf("expected list, got %s", listResp.Object) + } + + if len(listResp.Data) != 1 { + t.Fatalf("expected 1, got %d", len(listResp.Data)) + } + + if listResp.Data[0].Id != "Test Model" { + t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id) + } + }, + }, + { + Name: "retrieve model", + Method: http.MethodGet, + Path: "/api/show/:model", + TestPath: "/api/show/test-model", + Handler: RetrieveMiddleware, + Endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ShowResponse{ + ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC), + }) + }, + Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + var retrieveResp Model + if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil { + t.Fatal(err) + } + + if retrieveResp.Object != "model" { + t.Fatalf("Expected object to be model, got %s", retrieveResp.Object) + } + + if retrieveResp.Id != "test-model" { + t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id) + } + }, + }, + } + + gin.SetMode(gin.TestMode) + router := gin.New() + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + router = gin.New() + router.Use(tc.Handler()) + router.Handle(tc.Method, tc.Path, tc.Endpoint) + req, _ := http.NewRequest(tc.Method, tc.TestPath, nil) + + if tc.Setup != nil { + tc.Setup(t, req) + } + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) 
+ + assert.Equal(t, http.StatusOK, resp.Code) + + tc.Expected(t, resp) + }) + } +} diff --git a/server/routes.go b/server/routes.go index 76ead072f..ad2364507 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1039,7 +1039,9 @@ func (s *Server) GenerateRoutes() http.Handler { r.GET("/api/ps", s.ProcessHandler) // Compatibility endpoints - r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler) + r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) + r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) + r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) for _, method := range []string{http.MethodGet, http.MethodHead} { r.Handle(method, "/", func(c *gin.Context) { diff --git a/server/routes_test.go b/server/routes_test.go index 5a5c0fbba..50eaf7e97 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -20,6 +20,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/openai" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -105,6 +106,24 @@ func Test_Routes(t *testing.T) { assert.Empty(t, len(modelList.Models)) }, }, + { + Name: "openai empty list", + Method: http.MethodGet, + Path: "/v1/models", + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + assert.Equal(t, "application/json", contentType) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var modelList openai.ListCompletion + err = json.Unmarshal(body, &modelList) + require.NoError(t, err) + + assert.Equal(t, "list", modelList.Object) + assert.Empty(t, modelList.Data) + }, + }, { Name: "Tags Handler (yes tags)", Method: http.MethodGet, @@ -128,6 +147,25 @@ func Test_Routes(t *testing.T) { assert.Equal(t, "test-model:latest", modelList.Models[0].Name) }, }, + { + Name: "openai list models with tags", + Method: http.MethodGet, + Path: "/v1/models", + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + assert.Equal(t, "application/json", contentType) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var modelList openai.ListCompletion + err = json.Unmarshal(body, &modelList) + require.NoError(t, err) + + assert.Len(t, modelList.Data, 1) + assert.Equal(t, "test-model:latest", modelList.Data[0].Id) + assert.Equal(t, "library", modelList.Data[0].OwnedBy) + }, + }, { Name: "Create Model Handler", Method: http.MethodPost, @@ -216,6 +254,24 @@ func Test_Routes(t *testing.T) { assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0") }, }, + { + Name: "openai retrieve model handler", + Method: http.MethodGet, + Path: "/v1/models/show-model", + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + assert.Equal(t, "application/json", contentType) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var retrieveResp api.RetrieveModelResponse + err = json.Unmarshal(body, &retrieveResp) + require.NoError(t, err) + + assert.Equal(t, "show-model", retrieveResp.Id) + assert.Equal(t, "library", retrieveResp.OwnedBy) + }, + }, } t.Setenv("OLLAMA_MODELS", t.TempDir()) From 69c04eecc4b969149e43d6941f06a7d60dc5d191 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 2 Jul 2024 12:46:14 -0700 Subject: [PATCH 52/54] Add windows radeon concurreny note --- docs/faq.md | 4 +++- 1 file 
changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/faq.md b/docs/faq.md index 841f1d13d..574112461 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -266,8 +266,10 @@ If there is insufficient available memory to load a new model request while one Parallel request processing for a given model results in increasing the context size by the number of parallel requests. For example, a 2K context with 4 parallel requests will result in an 8K context and additional memory allocation. -The following server settings may be used to adjust how Ollama handles concurrent requests: +The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms: - `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference. - `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory. - `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512 + +Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM. \ No newline at end of file From d626b99b547c43e57390cec90ba2ae01adf0f429 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Tue, 2 Jul 2024 16:01:45 -0700 Subject: [PATCH 53/54] OpenAI: v1/completions compatibility (#5209) * OpenAI v1 models * Refactor Writers * Add Test Co-Authored-By: Attila Kerekes * Credit Co-Author Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com> * Empty List Testing * Use Namespace for Ownedby * Update Test * Add back envconfig * v1/models docs * Use ModelName Parser * Test Names * Remove Docs * Clean Up * Test name Co-authored-by: Jeffrey Morgan * Add Middleware for Chat and List * Completions Endpoint * Testing Cleanup * Test with Fatal * Add functionality to chat test * Rename function * float types * type cleanup * cleaning * more cleaning * Extra test cases * merge conflicts * merge conflicts * merge conflicts * merge conflicts * cleaning * cleaning --------- Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com> Co-authored-by: Jeffrey Morgan --- openai/openai.go | 223 +++++++++++++++++++++++++++++++++++++++++- openai/openai_test.go | 132 ++++++++++++++++++++++++- server/routes.go | 1 + 3 files changed, 353 insertions(+), 3 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 01da44409..f1e75bf21 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -43,6 +43,12 @@ type ChunkChoice struct { FinishReason *string `json:"finish_reason"` } +type CompleteChunkChoice struct { + Text string `json:"text"` + Index int `json:"index"` + FinishReason *string `json:"finish_reason"` +} + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` @@ -86,6 +92,39 @@ type ChatCompletionChunk struct { Choices []ChunkChoice `json:"choices"` } +// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int +type CompletionRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` 
+ FrequencyPenalty float32 `json:"frequency_penalty"` + MaxTokens *int `json:"max_tokens"` + PresencePenalty float32 `json:"presence_penalty"` + Seed *int `json:"seed"` + Stop any `json:"stop"` + Stream bool `json:"stream"` + Temperature *float32 `json:"temperature"` + TopP float32 `json:"top_p"` +} + +type Completion struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []CompleteChunkChoice `json:"choices"` + Usage Usage `json:"usage,omitempty"` +} + +type CompletionChunk struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Choices []CompleteChunkChoice `json:"choices"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` +} + type Model struct { Id string `json:"id"` Object string `json:"object"` @@ -158,6 +197,52 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { } } +func toCompletion(id string, r api.GenerateResponse) Completion { + return Completion{ + Id: id, + Object: "text_completion", + Created: r.CreatedAt.Unix(), + Model: r.Model, + SystemFingerprint: "fp_ollama", + Choices: []CompleteChunkChoice{{ + Text: r.Response, + Index: 0, + FinishReason: func(reason string) *string { + if len(reason) > 0 { + return &reason + } + return nil + }(r.DoneReason), + }}, + Usage: Usage{ + // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count + PromptTokens: r.PromptEvalCount, + CompletionTokens: r.EvalCount, + TotalTokens: r.PromptEvalCount + r.EvalCount, + }, + } +} + +func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { + return CompletionChunk{ + Id: id, + Object: "text_completion", + Created: time.Now().Unix(), + Model: r.Model, + SystemFingerprint: "fp_ollama", + Choices: []CompleteChunkChoice{{ + Text: r.Response, + Index: 0, + FinishReason: func(reason string) *string { + if len(reason) > 0 { + return &reason + } + return nil + }(r.DoneReason), + }}, + } +} + func toListCompletion(r api.ListResponse) ListCompletion { var data []Model for _, m := range r.Models { @@ -195,7 +280,7 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { switch stop := r.Stop.(type) { case string: options["stop"] = []string{stop} - case []interface{}: + case []any: var stops []string for _, s := range stop { if str, ok := s.(string); ok { @@ -247,6 +332,52 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { } } +func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { + options := make(map[string]any) + + switch stop := r.Stop.(type) { + case string: + options["stop"] = []string{stop} + case []string: + options["stop"] = stop + default: + if r.Stop != nil { + return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop) + } + } + + if r.MaxTokens != nil { + options["num_predict"] = *r.MaxTokens + } + + if r.Temperature != nil { + options["temperature"] = *r.Temperature * 2.0 + } else { + options["temperature"] = 1.0 + } + + if r.Seed != nil { + options["seed"] = *r.Seed + } + + options["frequency_penalty"] = r.FrequencyPenalty * 2.0 + + options["presence_penalty"] = r.PresencePenalty * 2.0 + + if r.TopP != 0.0 { + options["top_p"] = r.TopP + } else { + options["top_p"] = 1.0 + } + + return api.GenerateRequest{ + Model: r.Model, + Prompt: r.Prompt, + Options: options, + Stream: &r.Stream, + }, nil +} + type BaseWriter struct { 
gin.ResponseWriter } @@ -257,6 +388,12 @@ type ChatWriter struct { BaseWriter } +type CompleteWriter struct { + stream bool + id string + BaseWriter +} + type ListWriter struct { BaseWriter } @@ -331,6 +468,55 @@ func (w *ChatWriter) Write(data []byte) (int, error) { return w.writeResponse(data) } +func (w *CompleteWriter) writeResponse(data []byte) (int, error) { + var generateResponse api.GenerateResponse + err := json.Unmarshal(data, &generateResponse) + if err != nil { + return 0, err + } + + // completion chunk + if w.stream { + d, err := json.Marshal(toCompleteChunk(w.id, generateResponse)) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + + if generateResponse.Done { + _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) + if err != nil { + return 0, err + } + } + + return len(data), nil + } + + // completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *CompleteWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(code, data) + } + + return w.writeResponse(data) +} + func (w *ListWriter) writeResponse(data []byte) (int, error) { var listResponse api.ListResponse err := json.Unmarshal(data, &listResponse) @@ -416,6 +602,41 @@ func RetrieveMiddleware() gin.HandlerFunc { } } +func CompletionsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req CompletionRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return + } + + var b bytes.Buffer + genReq, err := fromCompleteRequest(req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return + } + + if err := json.NewEncoder(&b).Encode(genReq); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &CompleteWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), + } + + c.Writer = w + + c.Next() + } +} + func ChatMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var req ChatCompletionRequest diff --git a/openai/openai_test.go b/openai/openai_test.go index 1f335b965..4d21382c6 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -3,9 +3,11 @@ package openai import ( "bytes" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -69,6 +71,8 @@ func TestMiddleware(t *testing.T) { req.Header.Set("Content-Type", "application/json") }, Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + var chatResp ChatCompletion if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { t.Fatal(err) @@ -83,6 +87,130 @@ func TestMiddleware(t *testing.T) { } }, }, + { + Name: "completions handler", + Method: http.MethodPost, + Path: "/api/generate", + TestPath: "/api/generate", + Handler: CompletionsMiddleware, + Endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.GenerateResponse{ + Response: "Hello!", + }) + }, + 
Setup: func(t *testing.T, req *http.Request) { + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + var completionResp Completion + if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil { + t.Fatal(err) + } + + if completionResp.Object != "text_completion" { + t.Fatalf("expected text_completion, got %s", completionResp.Object) + } + + if completionResp.Choices[0].Text != "Hello!" { + t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text) + } + }, + }, + { + Name: "completions handler with params", + Method: http.MethodPost, + Path: "/api/generate", + TestPath: "/api/generate", + Handler: CompletionsMiddleware, + Endpoint: func(c *gin.Context) { + var generateReq api.GenerateRequest + if err := c.ShouldBindJSON(&generateReq); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) + return + } + + temperature := generateReq.Options["temperature"].(float64) + var assistantMessage string + + switch temperature { + case 1.6: + assistantMessage = "Received temperature of 1.6" + default: + assistantMessage = fmt.Sprintf("Received temperature of %f", temperature) + } + + c.JSON(http.StatusOK, api.GenerateResponse{ + Response: assistantMessage, + }) + }, + Setup: func(t *testing.T, req *http.Request) { + temp := float32(0.8) + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + Temperature: &temp, + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + var completionResp Completion + if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil { + t.Fatal(err) + } + + if completionResp.Object != "text_completion" { + t.Fatalf("expected text_completion, got %s", completionResp.Object) + } + + if completionResp.Choices[0].Text != "Received temperature of 1.6" { + t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text) + } + }, + }, + { + Name: "completions handler with error", + Method: http.MethodPost, + Path: "/api/generate", + TestPath: "/api/generate", + Handler: CompletionsMiddleware, + Endpoint: func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) + }, + Setup: func(t *testing.T, req *http.Request) { + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.Code) + } + + if !strings.Contains(resp.Body.String(), `"invalid request"`) { + t.Fatalf("error was not forwarded") + } + }, + }, { Name: "list handler", Method: http.MethodGet, @@ -99,6 +227,8 @@ func TestMiddleware(t *testing.T) { }) }, Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + var listResp ListCompletion if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { t.Fatal(err) @@ -162,8 +292,6 @@ func 
TestMiddleware(t *testing.T) { resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusOK, resp.Code) - tc.Expected(t, resp) }) } diff --git a/server/routes.go b/server/routes.go index 9fe5fcc4e..41c920844 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1054,6 +1054,7 @@ func (s *Server) GenerateRoutes() http.Handler { // Compatibility endpoints r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) + r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) From 65a5040e09d34b4e4237a4ac1996e2fb2a112bb3 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 2 Jul 2024 16:42:17 -0700 Subject: [PATCH 54/54] fix generate template --- server/routes.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/server/routes.go b/server/routes.go index 41c920844..b14a146c1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -176,11 +176,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { prompt = req.Prompt case req.Prompt != "": if req.Template == "" { - model.Template, err = template.Parse(req.Template) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + tmpl = model.Template } if req.System == "" {