Compare commits

..

9 Commits

Author SHA1 Message Date
Bruce MacDonald
60f0b7db76 working 2025-01-24 16:51:19 -08:00
Jesse Gross
0d22c0ec1a new runner 2025-01-21 10:54:22 -08:00
Jesse Gross
a34960a516 tensor loading iface 2025-01-21 10:54:16 -08:00
Michael Yang
9722d96db2 docker 2025-01-20 17:06:45 -08:00
Michael Yang
fd868e5a88 default native cuda architecture when possible 2025-01-20 16:19:52 -08:00
Michael Yang
a48bd5bbe7 cpu all 2025-01-20 15:40:05 -08:00
Michael Yang
ffd9914d0a update ci 2025-01-20 09:39:43 -08:00
Michael Yang
9d249bfcab sort devices by type 2025-01-20 09:39:43 -08:00
Michael Yang
8cb7b94c40 next ollama runner
implement llama and mllama model architectures in go using ggml (through
cgo)
2025-01-20 09:39:43 -08:00
442 changed files with 48475 additions and 553409 deletions

View File

@@ -3,9 +3,7 @@ ollama
app
macapp
dist
build
.env
.cache
test_data
.git
llama/build

4
.gitattributes vendored
View File

@@ -15,10 +15,6 @@ ml/backend/**/*.cu linguist-vendored
ml/backend/**/*.cuh linguist-vendored
ml/backend/**/*.m linguist-vendored
ml/backend/**/*.metal linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated
* text=auto
*.go text eol=lf

View File

@@ -9,14 +9,6 @@ body:
description: What happened? What did you expect to happen?
validations:
required: true
- type: textarea
id: logs
attributes:
label: Relevant log output
description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
render: shell
validations:
required: false
- type: dropdown
id: os
attributes:

View File

@@ -1,66 +1,31 @@
name: release
env:
ROCM_WINDOWS_URL: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe
MSYS2_URL: https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-x86_64-20240727.exe
on:
push:
tags:
- 'v*'
env:
CGO_CFLAGS: '-O3'
CGO_CXXFLAGS: '-O3'
jobs:
setup-environment:
runs-on: ubuntu-latest
environment: release
outputs:
GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }}
steps:
- uses: actions/checkout@v4
- name: Set environment
id: goflags
run: |
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
darwin-build:
# Full build of the Mac assets
build-darwin:
runs-on: macos-13
environment: release
needs: setup-environment
strategy:
matrix:
os: [darwin]
arch: [amd64, arm64]
env:
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
- run: |
go build -o dist/ .
- name: Set Version
shell: bash
run: |
echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
echo "RELEASE_VERSION=$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)" >> $GITHUB_ENV
- name: key
env:
GOOS: ${{ matrix.os }}
GOARCH: ${{ matrix.arch }}
CGO_ENABLED: 1
CGO_CPPFLAGS: '-mmacosx-version-min=11.3'
- if: matrix.arch == 'amd64'
MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
run: |
cmake --preset CPU -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64
cmake --build --parallel --preset CPU
cmake --install build --component CPU --strip --parallel 8
- uses: actions/upload-artifact@v4
with:
name: build-${{ matrix.os }}-${{ matrix.arch }}
path: dist/*
darwin-sign:
runs-on: macos-13
environment: release
needs: darwin-build
steps:
- uses: actions/checkout@v4
- run: |
echo $MACOS_SIGNING_KEY | base64 --decode > certificate.p12
security create-keychain -p password build.keychain
security default-keychain -s build.keychain
@@ -68,20 +33,11 @@ jobs:
security import certificate.p12 -k build.keychain -P $MACOS_SIGNING_KEY_PASSWORD -T /usr/bin/codesign
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
security set-keychain-settings -lut 3600 build.keychain
env:
MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
- uses: actions/download-artifact@v4
- uses: actions/setup-go@v5
with:
name: build-darwin-amd64
path: dist/darwin-amd64
- uses: actions/download-artifact@v4
with:
name: build-darwin-arm64
path: dist/darwin-arm64
- run: |
export VERSION=${GITHUB_REF_NAME#v}
./scripts/build_darwin.sh sign macapp
go-version-file: go.mod
cache: true
- name: Build Darwin
env:
APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
@@ -89,324 +45,485 @@ jobs:
APPLE_ID: ${{ vars.APPLE_ID }}
SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer
run: |
./scripts/build_darwin.sh
- uses: actions/upload-artifact@v4
with:
name: dist-darwin
path: |
dist/Ollama-darwin.zip
dist/ollama-darwin.tgz
dist/ollama-darwin
windows-depends:
strategy:
matrix:
os: [windows]
arch: [amd64]
preset: ['CPU']
include:
- os: windows
arch: amd64
preset: 'CUDA 11'
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
cuda-version: '11.3'
- os: windows
arch: amd64
preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-version: '12.8'
- os: windows
arch: amd64
preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
# Windows builds take a long time to both install the dependencies and build, so parallelize
# CPU generation step
generate-windows-cpu:
environment: release
runs-on: windows
env:
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
steps:
- name: Install system dependencies
run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ')
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
key: ${{ matrix.install }}
- if: startsWith(matrix.preset, 'CUDA ')
name: Install CUDA ${{ matrix.cuda-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
$subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
}
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
echo "$cudaPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: startsWith(matrix.preset, 'ROCm')
name: Install ROCm ${{ matrix.rocm-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList '-install' -NoNewWindow -Wait
}
$hipPath = (Resolve-Path "C:\Program Files\AMD\ROCm\*").path
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
- name: Build target "${{ matrix.preset }}"
- name: Set make jobs default
run: |
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}"
cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
env:
CMAKE_GENERATOR: Ninja
- uses: actions/upload-artifact@v4
with:
name: depends-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
path: dist\*
windows-build:
strategy:
matrix:
os: [windows]
arch: [amd64, arm64]
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: [setup-environment]
env:
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
steps:
- name: Install AMD64 system dependencies
if: matrix.arch == 'amd64'
echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
- name: Set Version
shell: bash
run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
- name: Add msys paths
run: |
$ErrorActionPreference = "Stop"
Start-Process "C:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
echo "C:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Install ARM64 system dependencies
if: matrix.arch == 'arm64'
- name: Install msys2 tools
run: |
$ErrorActionPreference = "Stop"
Set-ExecutionPolicy Bypass -Scope Process -Force
[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072
iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
echo "C:\ProgramData\chocolatey\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
choco install -y --no-progress git gzip
echo "C:\Program Files\Git\cmd" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-aarch64.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip"
Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip -DestinationPath "C:\Program Files\"
$installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt-aarch64").path
echo $installPath\bin | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- uses: actions/checkout@v4
Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- run: |
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
- if: matrix.arch == 'arm64'
run: |
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
- run: |
$env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
& .\scripts\build_windows.ps1 buildApp
env:
VCToolsRedistDir: stub
import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo'
if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" }
make dist
name: make
- uses: actions/upload-artifact@v4
with:
name: build-${{ matrix.os }}-${{ matrix.arch }}
name: generate-windows-cpu
path: |
dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
dist/windows-amd64/**
windows-sign:
runs-on: windows-2022
# ROCm generation step
generate-windows-rocm:
environment: release
needs: [windows-depends, windows-build]
runs-on: windows
env:
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
steps:
- uses: actions/checkout@v4
- uses: google-github-actions/auth@v2
- name: Set make jobs default
run: |
echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
- name: Set Version
shell: bash
run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
- name: Add msys paths
run: |
echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Install msys2 tools
run: |
Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
- uses: actions/setup-go@v5
with:
project_id: ollama
credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
- run: |
go-version-file: go.mod
cache: true
# ROCM installation steps
- name: 'Cache ROCm installer'
id: cache-rocm
uses: actions/cache@v4
with:
path: rocm-install.exe
key: ${{ env.ROCM_WINDOWS_URL }}
- name: 'Conditionally Download ROCm'
if: steps.cache-rocm.outputs.cache-hit != 'true'
run: |
$ErrorActionPreference = "Stop"
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
& "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
- uses: actions/download-artifact@v4
Invoke-WebRequest -Uri "${env:ROCM_WINDOWS_URL}" -OutFile "rocm-install.exe"
- name: 'Install ROCm'
run: |
Start-Process "rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
- name: 'Verify ROCm'
run: |
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
echo "HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path | select -first 1)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
- name: make rocm runner
run: |
import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo'
if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" }
make help-runners
make dist_rocm
- uses: actions/upload-artifact@v4
with:
pattern: build-windows-*
path: dist\
merge-multiple: true
- uses: actions/download-artifact@v4
name: generate-windows-rocm
path: |
dist/windows-amd64/**
# CUDA generation step
generate-windows-cuda:
environment: release
runs-on: windows
strategy:
matrix:
cuda:
- version: "11.3"
url: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
- version: "12.4"
url: https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_551.61_windows.exe
env:
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
steps:
- uses: actions/checkout@v4
- name: Set make jobs default
run: |
echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
- name: Set Version
shell: bash
run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
- name: Install msys2
run: |
$msys2_url="https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-x86_64-20240727.exe"
write-host "Downloading msys2"
Invoke-WebRequest -Uri "${msys2_url}" -OutFile "${env:RUNNER_TEMP}\msys2.exe"
write-host "Installing msys2"
Start-Process "${env:RUNNER_TEMP}\msys2.exe" -ArgumentList @("in", "--confirm-command", "--accept-messages", "--root", "C:/msys64") -NoNewWindow -Wait
echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Install msys2 tools
run: |
Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang", "make") -NoNewWindow -Wait
echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: verify tools
run: |
get-command gcc
gcc --version
get-command make
make --version
- uses: actions/setup-go@v5
with:
pattern: depends-windows-amd64-*
path: dist\windows-amd64\
merge-multiple: true
go-version-file: go.mod
cache: true
# CUDA installation steps
- name: 'Cache CUDA installer'
id: cache-cuda
uses: actions/cache@v4
with:
path: cuda-install.exe
key: ${{ matrix.cuda.url }}
- name: 'Conditionally Download CUDA'
if: steps.cache-cuda.outputs.cache-hit != 'true'
run: |
$ErrorActionPreference = "Stop"
Invoke-WebRequest -Uri "${{ matrix.cuda.url }}" -OutFile "cuda-install.exe"
- name: 'Install CUDA'
run: |
$subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | foreach-object {"${_}_${{ matrix.cuda.version }}"}
Start-Process "cuda-install.exe" -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
- name: 'Verify CUDA'
run: |
& (resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0] --version
$cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path)
$cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2'
echo "$cudaPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CUDA_PATH=$cudaPath" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
echo "CUDA_PATH_V${cudaVer}=$cudaPath" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
- name: make cuda runner
run: |
import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo'
if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" }
make dist_cuda_v$(($env:CUDA_PATH | split-path -leaf) -replace 'v(\d+).*', '$1')
- uses: actions/upload-artifact@v4
with:
name: generate-windows-cuda-${{ matrix.cuda.version }}
path: |
dist/windows-amd64/**
# windows arm64 generate, go build, and zip file (no installer)
# Output of this build is aggregated into the final x86 build
# for a unified windows installer
windows-arm64:
runs-on: windows-arm64
environment: release
env:
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
steps:
# The current Windows arm64 beta image has effectively zero dev tools installed...
- name: Install git and gzip
run: |
Set-ExecutionPolicy Bypass -Scope Process -Force
[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072
iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
choco install -y --no-progress git gzip
echo "C:\Program Files\Git\cmd" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "C:\ProgramData\chocolatey\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
# pacman is buggy on win arm64, so we avoid using it, but rely on the binary artifacts
# we download the sfx (7zip bundle) which isn't fully set up, but the binaries we need to build work
- name: Install msys2 x64
run: |
$url="https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-base-x86_64-20240727.sfx.exe"
write-host "Downloading MSYS2"
Invoke-WebRequest -Uri "$url" -outfile "${env:RUNNER_TEMP}\msys2.exe"
write-host "Installing msys2"
Start-Process "${env:RUNNER_TEMP}\msys2.exe" -ArgumentList @(
'-y', '-oC:\'
) -NoNewWindow -Wait
echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
# since pacman isn't reliable, we just download the tar file and extract directly
- name: Downloading and extracting msys2 make tar file
run: |
$url="https://mirror.msys2.org/msys/x86_64/make-4.4.1-2-x86_64.pkg.tar.zst"
write-host "Downloading make"
Invoke-WebRequest -Uri "$url" -outfile c:\msys64\make.tar.zst
cd c:\msys64; tar -xf make.tar.zst
rm c:\msys64\make.tar.zst
- name: Verify Make works properly
run: |
echo $env:PATH
make --version
- name: Install Visual Studio 2022
run: |
$components = @(
"Microsoft.VisualStudio.Component.CoreEditor",
"Microsoft.VisualStudio.Workload.CoreEditor",
"Microsoft.VisualStudio.Component.Roslyn.Compiler",
"Microsoft.Component.MSBuild",
"Microsoft.VisualStudio.Component.TextTemplating",
"Microsoft.VisualStudio.Component.Debugger.JustInTime",
"Microsoft.VisualStudio.Component.VC.CoreIde",
"Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
"Microsoft.VisualStudio.Component.Windows11SDK.22621",
"Microsoft.VisualStudio.Component.VC.Tools.ARM64EC",
"Microsoft.VisualStudio.Component.VC.Tools.ARM64",
"Microsoft.VisualStudio.Component.VC.ATL",
"Microsoft.VisualStudio.Component.VC.ATL.ARM64",
"Microsoft.VisualStudio.Component.Graphics",
"Microsoft.VisualStudio.Component.VC.Redist.14.Latest",
"Microsoft.VisualStudio.ComponentGroup.NativeDesktop.Core",
"Microsoft.VisualStudio.Component.Windows11Sdk.WindowsPerformanceToolkit",
"Microsoft.VisualStudio.Component.CppBuildInsights",
"Microsoft.VisualStudio.Component.VC.DiagnosticTools",
"Microsoft.VisualStudio.ComponentGroup.WebToolsExtensions.CMake",
"Microsoft.VisualStudio.Component.VC.CMake.Project",
"Microsoft.VisualStudio.Component.VC.ASAN",
"Microsoft.VisualStudio.Component.Vcpkg",
"Microsoft.VisualStudio.Workload.NativeDesktop"
)
$config = @{
"version" = "1.0"
"components" = $components
"extensions" = @()
}
$configPath = "${env:RUNNER_TEMP}\vsconfig"
$config | ConvertTo-Json | Out-File -FilePath $configPath
$bootstrapperFilePath = "${env:RUNNER_TEMP}\vs_community.exe"
write-host "Downloading Visual Studio 2022"
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vs_community.exe" -outfile $bootstrapperFilePath
$bootstrapperArgumentList = ('/c', $bootstrapperFilePath, '--config', $configPath, '--quiet', '--wait' )
write-host "Installing Visual Studio 2022"
$process = Start-Process -FilePath cmd.exe -ArgumentList $bootstrapperArgumentList -Wait -PassThru
$exitCode = $process.ExitCode
write-host $exitCode
# pacman in mingw/msys2 is ~broken on windows arm right now - hangs consistently during attempts to install
# so we'll use this alternative GCC binary
- name: Install llvm-mingw GCC
run: |
$gcc_url="https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-aarch64.zip"
write-host "Downloading llvm-mingw"
Invoke-WebRequest -Uri "${gcc_url}" -OutFile "${env:RUNNER_TEMP}\gcc.zip"
write-host "Unpacking llvm-mingw"
expand-archive -path "${env:RUNNER_TEMP}\gcc.zip" -destinationpath "c:\"
mv c:\llvm-mingw-* c:\llvm-mingw
echo "c:\llvm-mingw\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Verify GCC
run: |
echo $env:PATH
gcc --version
- uses: actions/checkout@v4
- name: Set Version
run: |
$ver=${env:GITHUB_REF_NAME}.trim("v")
echo VERSION=$ver | Out-File -FilePath ${env:GITHUB_ENV} -Encoding utf8 -Append
- uses: 'google-github-actions/auth@v2'
with:
project_id: 'ollama'
credentials_json: '${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}'
- run: echo "${{ vars.OLLAMA_CERT }}" | Out-File -FilePath ollama_inc.crt -Encoding utf8
- name: install Windows SDK 8.1 to get signtool
run: |
$ErrorActionPreference = "Stop"
write-host "downloading SDK"
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${env:RUNNER_TEMP}\sdksetup.exe"
Start-Process "${env:RUNNER_TEMP}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
write-host "Win SDK 8.1 installed"
gci -path 'C:\Program Files (x86)\Windows Kits\' -r -fi 'signtool.exe'
- name: install signing plugin
run: |
$ErrorActionPreference = "Stop"
write-host "downloading plugin"
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${env:RUNNER_TEMP}\plugin.zip"
Expand-Archive -Path "${env:RUNNER_TEMP}\plugin.zip" -DestinationPath ${env:RUNNER_TEMP}\plugin\
write-host "Installing plugin"
& "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet
write-host "plugin installed"
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- run: go get ./...
- run: |
& .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
env:
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
$gopath=(get-command go).source | split-path -parent
$gccpath=(get-command gcc).source | split-path -parent
import-module 'C:\Program Files\Microsoft Visual Studio\2022\Community\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -Arch arm64 -vsinstallpath 'C:\Program Files\Microsoft Visual Studio\2022\Community' -skipautomaticlocation
$env:PATH="$gopath;$gccpath;$env:PATH"
echo $env:PATH
$env:ARCH="arm64"
.\scripts\build_windows.ps1 buildOllama buildApp gatherDependencies sign distZip
name: 'Windows Build'
- uses: actions/upload-artifact@v4
with:
name: windows-arm64
path: |
dist/windows-arm64/**
dist/windows-arm64-app.exe
dist/ollama-windows-arm64.zip
# Import the prior generation steps plus the full arm64 build, and build the final windows assets
build-windows:
environment: release
runs-on: windows
needs:
- generate-windows-cuda
- generate-windows-rocm
- generate-windows-cpu
- windows-arm64
env:
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Set Version
shell: bash
run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
- uses: 'google-github-actions/auth@v2'
with:
project_id: 'ollama'
credentials_json: '${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}'
- run: echo "${{ vars.OLLAMA_CERT }}" > ollama_inc.crt
- name: install Windows SDK 8.1 to get signtool
run: |
$ErrorActionPreference = "Stop"
write-host "downloading SDK"
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${env:RUNNER_TEMP}\sdksetup.exe"
Start-Process "${env:RUNNER_TEMP}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
write-host "Win SDK 8.1 installed"
gci -path 'C:\Program Files (x86)\Windows Kits\' -r -fi 'signtool.exe'
- name: install signing plugin
run: |
$ErrorActionPreference = "Stop"
write-host "downloading plugin"
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${env:RUNNER_TEMP}\plugin.zip"
Expand-Archive -Path "${env:RUNNER_TEMP}\plugin.zip" -DestinationPath ${env:RUNNER_TEMP}\plugin\
write-host "Installing plugin"
& "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet
write-host "plugin installed"
- name: Install msys2
run: |
$msys2_url="https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-x86_64-20240727.exe"
write-host "Downloading msys2"
Invoke-WebRequest -Uri "${msys2_url}" -OutFile "${env:RUNNER_TEMP}\msys2.exe"
write-host "Installing msys2"
Start-Process "${env:RUNNER_TEMP}\msys2.exe" -ArgumentList @("in", "--confirm-command", "--accept-messages", "--root", "C:/msys64") -NoNewWindow -Wait
echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Install msys2 tools
run: |
Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang", "make") -NoNewWindow -Wait
echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: verify tools
run: |
get-command gcc
gcc --version
get-command make
make --version
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- run: go get
- uses: actions/download-artifact@v4
with:
name: generate-windows-cpu
path: dist/windows-amd64/
- uses: actions/download-artifact@v4
with:
name: generate-windows-cuda-11.3
path: dist/windows-amd64/
- uses: actions/download-artifact@v4
with:
name: generate-windows-cuda-12.4
path: dist/windows-amd64/
- uses: actions/download-artifact@v4
with:
name: generate-windows-rocm
path: dist/windows-amd64/
- uses: actions/download-artifact@v4
with:
name: windows-arm64
path: dist
- run: |
import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo'
$env:OLLAMA_SKIP_GENERATE="1"
$env:ARCH="amd64"
if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" }
& .\scripts\build_windows.ps1
- uses: actions/upload-artifact@v4
with:
name: dist-windows
path: |
dist\OllamaSetup.exe
dist\ollama-windows-*.zip
dist/OllamaSetup.exe
dist/ollama-windows-*.zip
linux-build:
strategy:
matrix:
include:
- os: linux
arch: amd64
target: archive
- os: linux
arch: amd64
target: rocm
- os: linux
arch: arm64
target: archive
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
build-linux:
environment: release
needs: setup-environment
env:
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
- uses: docker/build-push-action@v6
with:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
target: ${{ matrix.target }}
build-args: |
GOFLAGS=${{ env.GOFLAGS }}
CGO_CFLAGS=${{ env.CGO_CFLAGS }}
CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=ollama/ollama:latest
cache-to: type=inline
- run: |
for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
esac
done
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
done
- uses: actions/upload-artifact@v4
with:
name: dist-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
path: |
*.tgz
# Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
docker-build-push:
strategy:
matrix:
include:
- os: linux
arch: arm64
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
- os: linux
arch: amd64
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
- os: linux
arch: amd64
suffix: '-rocm'
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
FLAVOR=rocm
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
env:
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
- uses: docker/login-action@v3
with:
username: ${{ vars.DOCKER_USER }}
password: ${{ secrets.DOCKER_ACCESS_TOKEN }}
- id: build-push
uses: docker/build-push-action@v6
with:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=ollama/ollama:latest
cache-to: type=inline
- run: |
mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
echo "${{ steps.build-push.outputs.digest }}" >${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.suffix }}.txt
working-directory: ${{ runner.temp }}
- uses: actions/upload-artifact@v4
with:
name: digest-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.suffix }}
path: |
${{ runner.temp }}/${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.suffix }}.txt
# Merge Docker images for the same flavor into a single multi-arch manifest
docker-merge-push:
strategy:
matrix:
suffix: ['', '-rocm']
runs-on: linux
environment: release
needs: [docker-build-push]
strategy:
matrix:
include:
- os: linux
arch: amd64
targets: [archive, rocm]
- os: linux
arch: arm64
targets: [archive]
steps:
- uses: actions/checkout@v4
- uses: docker/setup-qemu-action@v3
- uses: docker/setup-buildx-action@v3
- run: |
apt-get update && apt-get install pigz
for TARGET in ${{ matrix.targets }}; do docker buildx build --platform $PLATFORM --target $TARGET --output type=local,dest=dist/$PLATFORM .; done
tar c -C dist/$PLATFORM . | pigz -9cv >dist/ollama-${PLATFORM//\//-}.tar.gz
env:
PLATFORM: ${{ matrix.os }}/${{ matrix.arch }}
- uses: actions/upload-artifact@v4
with:
name: dist-${{ matrix.os }}-${{ matrix.arch }}
path: |
dist/ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.gz
build-docker:
environment: release
runs-on: ubuntu-latest
strategy:
matrix:
include:
- flavor: |
latest=auto
platforms: linux/amd64,linux/arm64
build-args: [GOFLAGS]
- flavor: |
suffix=-rocm,onlatest=false
platforms: linux/amd64
build-args: [GOFLAGS, FLAVOR=rocm]
steps:
- uses: actions/checkout@v4
- uses: docker/setup-qemu-action@v2
- uses: docker/setup-buildx-action@v2
- uses: docker/login-action@v3
with:
username: ${{ vars.DOCKER_USER }}
@@ -414,27 +531,32 @@ jobs:
- id: metadata
uses: docker/metadata-action@v4
with:
flavor: |
latest=false
suffix=${{ matrix.suffix }}
flavor: ${{ matrix.flavor }}
images: |
ollama/ollama
tags: |
type=ref,enable=true,priority=600,prefix=pr-,event=pr
type=semver,pattern={{version}}
- uses: actions/download-artifact@v4
- uses: docker/build-push-action@v6
with:
pattern: digest-*
path: ${{ runner.temp }}
merge-multiple: true
- run: |
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
working-directory: ${{ runner.temp }}
context: .
push: true
platforms: ${{ matrix.platforms }}
build-args: ${{ matrix.build-args }}
tags: ${{ steps.metadata.outputs.tags }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=registry,ref=ollama/ollama:latest
cache-to: type=inline
provenance: false
env:
GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${{ steps.metadata.outputs.version }}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
# Aggregate all the assets and ship a release
release:
needs: [darwin-sign, windows-sign, linux-build]
needs:
- build-darwin
- build-windows
- build-linux-amd64
- build-linux-arm64
runs-on: linux
environment: release
permissions:
@@ -443,34 +565,33 @@ jobs:
GH_TOKEN: ${{ github.token }}
steps:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
- name: Set Version
shell: bash
run: |
echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
echo "RELEASE_VERSION=$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)" >> $GITHUB_ENV
- name: Retrieve built artifact
uses: actions/download-artifact@v4
with:
name: dist-darwin
path: dist
- uses: actions/download-artifact@v4
with:
name: dist-windows
path: dist
- uses: actions/download-artifact@v4
with:
pattern: dist-linux-*
path: dist
pattern: dist-*
merge-multiple: true
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
working-directory: dist
- run: |
ls -lh dist/
(cd dist; find . -type f | xargs sha256sum > ../sha256sum.txt)
mv sha256sum.txt dist/
cat dist/sha256sum.txt
- name: Create or update Release
run: |
RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
echo "Looking for existing release for ${RELEASE_VERSION}"
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
echo "Looking for existing release for ${{ env.RELEASE_VERSION }}"
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${{ env.RELEASE_VERSION }}\") | .tagName")
if [ -n "$OLD_TAG" ]; then
echo "Updating release ${RELEASE_VERSION} to point to new tag ${GITHUB_REF_NAME}"
echo "Updating release ${{ env.RELEASE_VERSION }} to point to new tag ${GITHUB_REF_NAME}"
gh release edit ${OLD_TAG} --tag ${GITHUB_REF_NAME}
else
echo "Creating new release ${RELEASE_VERSION} pointing to tag ${GITHUB_REF_NAME}"
echo "Creating new release ${{ env.RELEASE_VERSION }} pointing to tag ${GITHUB_REF_NAME}"
gh release create ${GITHUB_REF_NAME} \
--title ${RELEASE_VERSION} \
--title ${{ env.RELEASE_VERSION }} \
--draft \
--generate-notes \
--prerelease

View File

@@ -40,113 +40,28 @@ jobs:
linux:
needs: [changes]
if: needs.changes.outputs.changed == 'True'
if: ${{ needs.changes.outputs.changed == 'True' }}
strategy:
matrix:
include:
- preset: CPU
- preset: CUDA
container: nvidia/cuda:11.8.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2
- container: nvidia/cuda:11.8.0-devel-ubuntu22.04
preset: CUDA
- container: rocm/dev-ubuntu-22.04:6.1.2
preset: ROCm
extra-packages: rocm-libs
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
runs-on: linux
runs-on: ubuntu-latest
container: ${{ matrix.container }}
steps:
- uses: actions/checkout@v4
- run: |
[ -n "${{ matrix.container }}" ] || sudo=sudo
$sudo apt-get update
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
apt-get update
apt-get install -y cmake pkg-config ${{ matrix.extra-packages }}
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/cache@v4
with:
path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
cmake --preset ${{ matrix.preset }}
cmake --build --preset ${{ matrix.preset }} --parallel
windows:
needs: [changes]
if: needs.changes.outputs.changed == 'True'
strategy:
matrix:
include:
- preset: CPU
- preset: CUDA
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010'
runs-on: windows
steps:
- run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm'
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
key: ${{ matrix.install }}
- if: matrix.preset == 'CUDA'
name: Install CUDA ${{ matrix.cuda-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
}
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
echo "$cudaPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- if: matrix.preset == 'ROCm'
name: Install ROCm ${{ matrix.rocm-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList '-install' -NoNewWindow -Wait
}
$hipPath = (Resolve-Path "C:\Program Files\AMD\ROCm\*").path
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --build --parallel --preset "${{ matrix.preset }}"
env:
CMAKE_GENERATOR: Ninja
go_mod_tidy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: check that 'go mod tidy' is clean
run: go mod tidy --diff || (echo "Please run 'go mod tidy'." && exit 1)
test:
strategy:
matrix:
@@ -154,82 +69,15 @@ jobs:
runs-on: ${{ matrix.os }}
env:
CGO_ENABLED: '1'
GOEXPERIMENT: 'synctest'
steps:
- name: checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
- name: cache restore
uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
# Note: unlike the other setups, this is only grabbing the mod download
# cache, rather than the whole mod directory, as the download cache
# contains zips that can be unpacked in parallel faster than they can be
# fetched and extracted by tar
path: |
~/.cache/go-build
~/go/pkg/mod/cache
~\AppData\Local\go-build
# NOTE: The -3- here should be incremented when the scheme of data to be
# cached changes (e.g. path above changes).
key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
restore-keys: |
${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}
${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-
- name: Setup Go
uses: actions/setup-go@v5
with:
# The caching strategy of setup-go is less than ideal, and wastes
# time by not saving artifacts due to small failures like the linter
# complaining, etc. This means subsequent have to rebuild their world
# again until all checks pass. For instance, if you mispell a word,
# you're punished until you fix it. This is more hostile than
# helpful.
cache: false
go-version-file: go.mod
# It is tempting to run this in a platform independent way, but the past
# shows this codebase will see introductions of platform specific code
# generation, and so we need to check this per platform to ensure we
# don't abuse go generate on specific platforms.
- name: check that 'go generate' is clean
if: always()
run: |
go generate ./...
git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1)
- name: go test
if: always()
run: go test -count=1 -benchtime=1x ./...
# TODO(bmizerany): replace this heavy tool with just the
# tools/checks/binaries we want and then make them all run in parallel
# across jobs, not on a single tiny vm on Github Actions.
- uses: golangci/golangci-lint-action@v6
with:
args: --timeout 10m0s -v
- name: cache save
# Always save the cache, even if the job fails. The artifacts produced
# during the building of test binaries are not all for naught. They can
# be used to speed up subsequent runs.
if: always()
uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
with:
# Note: unlike the other setups, this is only grabbing the mod download
# cache, rather than the whole mod directory, as the download cache
# contains zips that can be unpacked in parallel faster than they can be
# fetched and extracted by tar
path: |
~/.cache/go-build
~/go/pkg/mod/cache
~\AppData\Local\go-build
# NOTE: The -3- here should be incremented when the scheme of data to be
# cached changes (e.g. path above changes).
key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
- run: go test ./...
patches:
runs-on: ubuntu-latest
@@ -237,5 +85,5 @@ jobs:
- uses: actions/checkout@v4
- name: Verify patches apply cleanly and do not change files
run: |
make -f Makefile.sync clean checkout apply-patches sync
git diff --compact-summary --exit-code
make -f Makefile2 clean checkout sync
git diff --compact-summary --exit-code

7
.gitignore vendored
View File

@@ -4,13 +4,14 @@
.venv
.swp
dist
build
ollama
.cache
*.exe
.idea
test_data
*.crt
__debug_bin*
llama/build
__debug_bin*
llama/vendor
/ollama
model/model_test/testdata/*/
!model/model_test/testdata/*.*

View File

@@ -6,6 +6,8 @@ linters:
- bidichk
- bodyclose
- containedctx
- contextcheck
- errcheck
- gocheckcompilerdirectives
- gofmt
- gofumpt
@@ -21,11 +23,10 @@ linters:
- staticcheck
- tenv
- unconvert
- unused
- usestdlibvars
- wastedassign
- whitespace
disable:
- usestdlibvars
- errcheck
linters-settings:
staticcheck:
checks:
@@ -38,4 +39,5 @@ severity:
- gofmt
- goimports
- intrange
- usestdlibvars
severity: info

View File

@@ -21,30 +21,12 @@ set(GGML_BACKEND_SHARED ON)
set(GGML_SCHED_MAX_COPIES 4)
set(GGML_LLAMAFILE ON)
set(GGML_CPU_ALL_VARIANTS ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)
set(GGML_CUDA_COMPRESSION_MODE default)
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
set(GGML_CPU_ALL_VARIANTS ON)
endif()
if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
set(CMAKE_BUILD_RPATH "@loader_path")
set(CMAKE_INSTALL_RPATH "@loader_path")
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
@@ -55,79 +37,16 @@ set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
get_target_property(CPU_VARIANTS ggml-cpu MANUALLY_ADDED_DEPENDENCIES)
if(NOT CPU_VARIANTS)
set(CPU_VARIANTS "ggml-cpu")
endif()
install(TARGETS ggml-base ${CPU_VARIANTS}
RUNTIME_DEPENDENCIES
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CPU
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CPU
FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CPU
)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24" AND NOT CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES "native")
endif()
find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
)
endif()
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
CACHE STRING
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
)
check_language(HIP)
if(CMAKE_HIP_COMPILER)
set(HIP_PLATFORM "amd")
find_package(hip REQUIRED)
if(NOT AMDGPU_TARGETS)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
endif()
if(AMDGPU_TARGETS)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
endif()
target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip
RUNTIME_DEPENDENCIES
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*"
POST_EXCLUDE_REGEXES "system32"
RUNTIME DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP
LIBRARY DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP
)
foreach(HIP_LIB_BIN_INSTALL_DIR IN ITEMS ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR})
if(EXISTS ${HIP_LIB_BIN_INSTALL_DIR}/rocblas)
install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP)
break()
endif()
endforeach()
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
endif()

View File

@@ -3,16 +3,11 @@
"configurePresets": [
{
"name": "Default",
"binaryDir": "${sourceDir}/build",
"installDir": "${sourceDir}/dist",
"binaryDir": "${sourceDir}/dist/build",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Release"
}
},
{
"name": "CPU",
"inherits": [ "Default" ]
},
{
"name": "CUDA",
"inherits": [ "Default" ]
@@ -21,16 +16,14 @@
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;62;70;72;75;80;86"
}
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
"CMAKE_CUDA_ARCHITECTURES": "60;61;62;70;72;75;80;86;87;89;90;90a"
}
},
{
@@ -49,16 +42,13 @@
},
{
"name": "ROCm",
"inherits": [ "Default" ],
"cacheVariables": {
"CMAKE_HIP_PLATFORM": "amd"
}
"inherits": [ "Default" ]
},
{
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"cacheVariables": {
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
"CMAKE_HIP_ARCHITECTURES": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102"
}
}
],
@@ -66,11 +56,7 @@
{
"name": "Default",
"configurePreset": "Default",
"configuration": "Release"
},
{
"name": "CPU",
"configurePreset": "Default",
"configuration": "Release",
"targets": [ "ggml-cpu" ]
},
{

View File

@@ -6,6 +6,8 @@ Thank you for your interest in contributing to Ollama! Here are a few guidelines
See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.
## Pull requests
### Ideal issues
* [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.
@@ -24,64 +26,11 @@ See the [development documentation](./docs/development.md) for instructions on h
* Changes that add significant friction to the user experience
* Changes that create a large future maintenance burden for maintainers and contributors
## Proposing a (non-trivial) change
### Best practices
> By "non-trivial", we mean a change that is not a bug fix or small
> documentation update. If you are unsure, please ask us on our [Discord
> server](https://discord.gg/ollama).
Before opening a non-trivial Pull Request, please open an issue to discuss the change and
get feedback from the maintainers. This helps us understand the context of the
change and how it fits into Ollama's roadmap and prevents us from duplicating
work or you from spending time on a change that we may not be able to accept.
Tips for proposals:
* Explain the problem you are trying to solve, not what you are trying to do.
* Explain why the change is important.
* Explain how the change will be used.
* Explain how the change will be tested.
Additionally, for bonus points: Provide draft documentation you would expect to
see if the change were accepted.
## Pull requests
**Commit messages**
The title should look like:
<package>: <short description>
The package is the most affected Go package. If the change does not affect Go
code, then use the directory name instead. Changes to a single well-known
file in the root directory may use the file name.
The short description should start with a lowercase letter and be a
continuation of the sentence:
"This changes Ollama to..."
Examples:
llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clairity on good commit messages, and bad
Bad Examples:
feat: add more emoji
fix: was not using famous web framework
chore: generify code
**Tests**
Please include tests. Strive to test behavior, not implementation.
**New dependencies**
Dependencies should be added sparingly. If you are adding a new dependency,
please explain why it is necessary and what other ways you attempted that
did not work without it.
* Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`) . In the description, leave a short 2-3 sentences that explain more about the change and its impact.
* Tests: please add test coverage to changes where possible.
* Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
## Need help?

View File

@@ -2,24 +2,22 @@
ARG FLAVOR=${TARGETARCH}
ARG ROCMVERSION=6.3.3
ARG ROCMVERSION=6.1.2
ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG JETPACK6VERSION=r36.2.0
ARG CMAKEVERSION=3.31.2
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCMVERSION}-complete AS base-amd64
RUN sed -i -e 's/mirror.centos.org/vault.centos.org/g' -e 's/^#.*baseurl=http/baseurl=http/g' -e 's/^mirrorlist=http/#mirrorlist=http/g' /etc/yum.repos.d/*.repo \
&& yum install -y yum-utils devtoolset-10-gcc devtoolset-10-gcc-c++ \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo \
&& curl -s -L https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz | tar -Jx -C /usr/local/bin --strip-components 1
ENV PATH=/opt/rh/devtoolset-10/root/usr/bin:/opt/rh/devtoolset-11/root/usr/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
FROM --platform=linux/arm64 rockylinux:8 AS base-arm64
# install epel-release for ccache
RUN yum install -y yum-utils epel-release \
&& dnf install -y clang ccache \
&& yum install -y clang ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
ENV CC=clang CXX=clang++
@@ -31,37 +29,29 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ENV LDFLAGS=-s
FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
# amd64 uses gcc which requires devtoolset-11 for AVX extensions while arm64 uses clang
RUN if [ "$(uname -m)" = "x86_64" ]; then yum install -y devtoolset-11-gcc devtoolset-11-gcc-c++; fi
ENV PATH=/opt/rh/devtoolset-11/root/usr/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
&& cmake --build --parallel --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel 8
cmake --preset 'Default' && cmake --build --parallel --preset 'Default'
FROM base AS cuda-11
ARG CUDA11VERSION=11.3
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
RUN yum install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel 8
cmake --preset 'CUDA 11' && cmake --build --parallel --preset 'CUDA 11'
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ARG CUDA12VERSION=12.4
RUN yum install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
&& cmake --build --parallel --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel 8
cmake --preset 'CUDA 12' && cmake --build --parallel --preset 'CUDA 12'
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
&& cmake --build --parallel --preset 'ROCm 6' \
&& cmake --install build --component HIP --strip --parallel 8
cmake --preset 'ROCm 6' && cmake --build --parallel --preset 'ROCm 6'
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
ARG CMAKEVERSION
@@ -70,9 +60,7 @@ RUN apt-get update && apt-get install -y curl ccache \
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 5' \
&& cmake --build --parallel --preset 'JetPack 5' \
&& cmake --install build --component CUDA --strip --parallel 8
cmake --preset 'JetPack 5' && cmake --build --parallel --preset 'JetPack 5'
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
ARG CMAKEVERSION
@@ -81,16 +69,13 @@ RUN apt-get update && apt-get install -y curl ccache \
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 6' \
&& cmake --build --parallel --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel 8
cmake --preset 'JetPack 6' && cmake --build --parallel --preset 'JetPack 6'
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ARG GOVERSION=1.23.4
RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
WORKDIR /go/src/github.com/ollama/ollama
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
@@ -98,20 +83,65 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=cuda-11 --chmod=644 \
dist/build/lib/libggml-cuda.so \
/usr/local/cuda/lib64/libcublas.so.11 \
/usr/local/cuda/lib64/libcublasLt.so.11 \
/usr/local/cuda/lib64/libcudart.so.11.0 \
/lib/ollama/cuda_v11/
COPY --from=cuda-12 --chmod=644 \
dist/build/lib/libggml-cuda.so \
/usr/local/cuda/lib64/libcublas.so.12 \
/usr/local/cuda/lib64/libcublasLt.so.12 \
/usr/local/cuda/lib64/libcudart.so.12 \
/lib/ollama/cuda_v12/
FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
COPY --from=cuda-11 --chmod=644 \
dist/build/lib/libggml-cuda.so \
/usr/local/cuda/lib64/libcublas.so.11 \
/usr/local/cuda/lib64/libcublasLt.so.11 \
/usr/local/cuda/lib64/libcudart.so.11.0 \
/lib/ollama/cuda_v11/
COPY --from=cuda-12 --chmod=644 \
dist/build/lib/libggml-cuda.so \
/usr/local/cuda/lib64/libcublas.so.12 \
/usr/local/cuda/lib64/libcublasLt.so.12 \
/usr/local/cuda/lib64/libcudart.so.12 \
/lib/ollama/cuda_v12/
COPY --from=jetpack-5 --chmod=644 \
dist/build/lib/libggml-cuda.so \
/usr/local/cuda/lib64/libcublas.so.11 \
/usr/local/cuda/lib64/libcublasLt.so.11 \
/usr/local/cuda/lib64/libcudart.so.11.0 \
/lib/ollama/cuda_jetpack5/
COPY --from=jetpack-6 --chmod=644 \
dist/build/lib/libggml-cuda.so \
/usr/local/cuda/lib64/libcublas.so.12 \
/usr/local/cuda/lib64/libcublasLt.so.12 \
/usr/local/cuda/lib64/libcudart.so.12 \
/lib/ollama/cuda_jetpack6/
FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
FROM --platform=linux/arm64 scratch AS rocm
COPY --from=rocm-6 --chmod=644 \
dist/build/lib/libggml-hip.so \
/opt/rocm/lib/libamdhip64.so.6 \
/opt/rocm/lib/libhipblas.so.2 \
/opt/rocm/lib/librocblas.so.4 \
/opt/rocm/lib/libamd_comgr.so.2 \
/opt/rocm/lib/libhsa-runtime64.so.1 \
/opt/rocm/lib/librocprofiler-register.so.0 \
/opt/amdgpu/lib64/libdrm_amdgpu.so.1 \
/opt/amdgpu/lib64/libdrm.so.2 \
/usr/lib64/libnuma.so.1 \
/lib/ollama/rocm/
COPY --from=rocm-6 /opt/rocm/lib/rocblas/ /lib/ollama/rocm/rocblas/
FROM ${FLAVOR} AS archive
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=cpu --chmod=644 \
dist/build/lib/libggml-base.so \
dist/build/lib/libggml-cpu-*.so \
/lib/ollama/
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:20.04
@@ -119,10 +149,10 @@ RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin
COPY --from=archive /bin/ /usr/bin/
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive /lib/ollama /usr/lib/ollama
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
COPY --from=archive /lib/ollama/ /usr/lib/ollama/
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/lib/ollama
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NVIDIA_VISIBLE_DEVICES=all
ENV OLLAMA_HOST=0.0.0.0:11434

View File

@@ -1,60 +0,0 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=2016f07bd106c73699ecbaace80f55db5ed95dac
.PHONY: help
help:
@echo "Available targets:"
@echo " sync Sync with upstream repositories"
@echo " checkout Checkout upstream repository"
@echo " apply-patches Apply patches to local repository"
@echo " format-patches Format patches from local repository"
@echo " clean Clean local repository"
@echo
@echo "Example:"
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
.PHONY: sync
sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml
.PHONY: llama/build-info.cpp
llama/build-info.cpp: llama/build-info.cpp.in
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
.PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
.PHONY: ml/backend/ggml/ggml
ml/backend/ggml/ggml: llama/vendor/ggml/
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
PATCHES=$(wildcard llama/patches/*.patch)
.PHONY: apply-patches
.NOTPARALLEL:
apply-patches: $(addsuffix ed, $(PATCHES))
%.patched: %.patch
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
.PHONY: checkout
checkout: $(WORKDIR)
git -C $(WORKDIR) fetch
git -C $(WORKDIR) checkout -f $(FETCH_HEAD)
$(WORKDIR):
git clone $(UPSTREAM) $(WORKDIR)
.PHONE: format-patches
format-patches: llama/patches
git -C $(WORKDIR) format-patch \
--no-signature \
--no-numbered \
--zero-commit \
-o $(realpath $<) \
$(FETCH_HEAD)
.PHONE: clean
clean: checkout
$(RM) $(addsuffix ed, $(PATCHES))

46
Makefile2 Normal file
View File

@@ -0,0 +1,46 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=46e3556e01b824e52395fb050b29804b6cff2a7c
all: sync
.PHONY: sync
sync: llama/llama.cpp ml/backend/ggml/ggml
.PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/ apply_patches
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
.PHONY: ml/backend/ggml/ggml apply_patches
ml/backend/ggml/ggml: llama/vendor/ggml/ apply_patches
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
PATCHES=$(wildcard llama/patches/*.patch)
.PHONY: apply_patches
.NOTPARALLEL:
apply_patches: $(addsuffix ed, $(PATCHES))
%.patched: %.patch
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
.PHONY: checkout
checkout: $(WORKDIR)
git -C $(WORKDIR) fetch
git -C $(WORKDIR) checkout -f $(FETCH_HEAD)
$(WORKDIR):
git clone $(UPSTREAM) $(WORKDIR)
.PHONE: format_patches
format_patches: llama/patches
git -C $(WORKDIR) format-patch \
--no-signature \
--no-numbered \
--zero-commit \
-o $(realpath $<) \
$(FETCH_HEAD)
.PHONE: clean
clean: checkout
$(RM) $(addsuffix ed, $(PATCHES))

110
README.md
View File

@@ -1,5 +1,5 @@
<div align="center">
  <a href="https://ollama.com">
  <a href="https://ollama.com" />
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</a>
</div>
@@ -18,7 +18,7 @@ Get up and running with large language models.
### Linux
```shell
```
curl -fsSL https://ollama.com/install.sh | sh
```
@@ -42,7 +42,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
```shell
```
ollama run llama3.2
```
@@ -54,13 +54,6 @@ Here are some example models that can be downloaded:
| Model | Parameters | Size | Download |
| ------------------ | ---------- | ----- | -------------------------------- |
| Gemma 3 | 1B | 815MB | `ollama run gemma3:1b` |
| Gemma 3 | 4B | 3.3GB | `ollama run gemma3` |
| Gemma 3 | 12B | 8.1GB | `ollama run gemma3:12b` |
| Gemma 3 | 27B | 17GB | `ollama run gemma3:27b` |
| QwQ | 32B | 20GB | `ollama run qwq` |
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
| Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` |
| Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` |
@@ -69,7 +62,10 @@ Here are some example models that can be downloaded:
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
| Phi 4 | 14B | 9.1GB | `ollama run phi4` |
| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
| Mistral | 7B | 4.1GB | `ollama run mistral` |
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
@@ -77,7 +73,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |
| Solar | 10.7B | 6.1GB | `ollama run solar` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -96,13 +92,13 @@ Ollama supports importing GGUF models in the Modelfile:
2. Create the model in Ollama
```shell
```
ollama create example -f Modelfile
```
3. Run the model
```shell
```
ollama run example
```
@@ -114,7 +110,7 @@ See the [guide](docs/import.md) on importing models for more information.
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3.2` model:
```shell
```
ollama pull llama3.2
```
@@ -149,13 +145,13 @@ For more information on working with a Modelfile, see the [Modelfile](docs/model
`ollama create` is used to create a model from a Modelfile.
```shell
```
ollama create mymodel -f ./Modelfile
```
### Pull a model
```shell
```
ollama pull llama3.2
```
@@ -163,13 +159,13 @@ ollama pull llama3.2
### Remove a model
```shell
```
ollama rm llama3.2
```
### Copy a model
```shell
```
ollama cp llama3.2 my-model
```
@@ -188,39 +184,37 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
```
ollama run llava "What's in this image? /Users/jmorgan/Desktop/smile.png"
The image features a yellow smiley face, which is likely the central focus of the picture.
```
> **Output**: The image features a yellow smiley face, which is likely the central focus of the picture.
### Pass the prompt as an argument
```shell
ollama run llama3.2 "Summarize this file: $(cat README.md)"
```
> **Output**: Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
$ ollama run llama3.2 "Summarize this file: $(cat README.md)"
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
```
### Show model information
```shell
```
ollama show llama3.2
```
### List models on your computer
```shell
```
ollama list
```
### List which models are currently loaded
```shell
```
ollama ps
```
### Stop a model which is currently running
```shell
```
ollama stop llama3.2
```
@@ -236,13 +230,13 @@ See the [developer guide](https://github.com/ollama/ollama/blob/main/docs/develo
Next, start the server:
```shell
```
./ollama serve
```
Finally, in a separate shell, run a model:
```shell
```
./ollama run llama3.2
```
@@ -252,7 +246,7 @@ Ollama has a REST API for running and managing models.
### Generate a response
```shell
```
curl http://localhost:11434/api/generate -d '{
"model": "llama3.2",
"prompt":"Why is the sky blue?"
@@ -261,7 +255,7 @@ curl http://localhost:11434/api/generate -d '{
### Chat with a model
```shell
```
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
@@ -277,7 +271,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [Hollama](https://github.com/fmaclen/hollama)
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
@@ -285,13 +278,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI)
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
@@ -325,7 +317,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
@@ -348,7 +339,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -362,7 +353,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [chat-ollama](https://github.com/annilq/chat-ollama) (a React Native client for Ollama)
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
@@ -379,26 +369,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
- [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
### Cloud
@@ -438,14 +408,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
### Apple Vision Pro
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
- [Enchanted](https://github.com/AugustDev/enchanted)
### Database
@@ -460,10 +426,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
- [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama)
- [Homebrew](https://formulae.brew.sh/formula/ollama)
- [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama)
- [Guix channel](https://codeberg.org/tusharhero/ollama-guix)
- [Nix package](https://search.nixos.org/packages?show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
- [Nix package](https://search.nixos.org/packages?channel=24.05&show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
- [Flox](https://flox.dev/blog/ollama-part-one)
### Libraries
@@ -516,21 +481,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
### Mobile
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
### Extensions & Plugins
@@ -574,18 +531,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.

View File

@@ -10,7 +10,7 @@
// repository].
//
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/api/examples
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
package api
import (
@@ -132,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader
var buf *bytes.Buffer
if data != nil {
bts, err := json.Marshal(data)
if err != nil {

View File

@@ -1,13 +1,6 @@
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
@@ -50,206 +43,3 @@ func TestClientFromEnvironment(t *testing.T) {
})
}
}
// testError represents an internal error type with status code and message
// this is used since the error response from the server is not a standard error struct
type testError struct {
message string
statusCode int
}
func (e testError) Error() string {
return e.message
}
func TestClientStream(t *testing.T) {
testCases := []struct {
name string
responses []any
wantErr string
}{
{
name: "immediate error response",
responses: []any{
testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
},
wantErr: "test error message",
},
{
name: "error after successful chunks, ok response",
responses: []any{
ChatResponse{Message: Message{Content: "partial response 1"}},
ChatResponse{Message: Message{Content: "partial response 2"}},
testError{
message: "mid-stream error",
statusCode: http.StatusOK,
},
},
wantErr: "mid-stream error",
},
{
name: "successful stream completion",
responses: []any{
ChatResponse{Message: Message{Content: "chunk 1"}},
ChatResponse{Message: Message{Content: "chunk 2"}},
ChatResponse{
Message: Message{Content: "final chunk"},
Done: true,
DoneReason: "stop",
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("expected http.Flusher")
}
w.Header().Set("Content-Type", "application/x-ndjson")
for _, resp := range tc.responses {
if errResp, ok := resp.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
flusher.Flush()
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err)
}
receivedChunks = append(receivedChunks, resp)
return nil
})
if tc.wantErr != "" {
if err == nil {
t.Fatal("expected error but got nil")
}
if !strings.Contains(err.Error(), tc.wantErr) {
t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
}{
{
name: "immediate error response",
response: testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
},
{
name: "server error response",
response: testError{
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
},
{
name: "successful response",
response: struct {
ID string `json:"id"`
Success bool `json:"success"`
}{
ID: "msg_123",
Success: true,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(tc.response); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var resp struct {
ID string `json:"id"`
Success bool `json:"success"`
}
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" {
if err == nil {
t.Fatalf("got nil, want error %q", tc.wantErr)
}
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
return
}
if err != nil {
t.Fatalf("got error %q, want nil", err)
}
if expectedResp, ok := tc.response.(struct {
ID string `json:"id"`
Success bool `json:"success"`
}); ok {
if resp.ID != expectedResp.ID {
t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
}
if resp.Success != expectedResp.Success {
t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
}
}
})
}
}

View File

@@ -2,10 +2,9 @@
Run the examples in this directory with:
```shell
```
go run example_name/main.go
```
## Chat - Chat with a model
- [chat/main.go](chat/main.go)

View File

@@ -10,9 +10,6 @@ import (
"strconv"
"strings"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// StatusError is an error with an HTTP status code and message.
@@ -76,13 +73,13 @@ type GenerateRequest struct {
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Images is an optional list of raw image bytes accompanying this
// Images is an optional list of base64-encoded images accompanying this
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -107,7 +104,7 @@ type ChatRequest struct {
Tools `json:"tools,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
}
type Tools []Tool
@@ -163,65 +160,19 @@ func (t *ToolCallFunctionArguments) String() string {
type Tool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
Function ToolFunction `json:"function"`
}
// PropertyType can be either a string or an array of strings
type PropertyType []string
// UnmarshalJSON implements the json.Unmarshaler interface
func (pt *PropertyType) UnmarshalJSON(data []byte) error {
// Try to unmarshal as a string first
var s string
if err := json.Unmarshal(data, &s); err == nil {
*pt = []string{s}
return nil
}
// If that fails, try to unmarshal as an array of strings
var a []string
if err := json.Unmarshal(data, &a); err != nil {
return err
}
*pt = a
return nil
}
// MarshalJSON implements the json.Marshaler interface
func (pt PropertyType) MarshalJSON() ([]byte, error) {
if len(pt) == 1 {
// If there's only one type, marshal as a string
return json.Marshal(pt[0])
}
// Otherwise marshal as an array
return json.Marshal([]string(pt))
}
// String returns a string representation of the PropertyType
func (pt PropertyType) String() string {
if len(pt) == 0 {
return ""
}
if len(pt) == 1 {
return pt[0]
}
return fmt.Sprintf("%v", []string(pt))
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}
@@ -307,7 +258,7 @@ type EmbedRequest struct {
Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
}
// EmbedResponse is the response from [Client.Embed].
@@ -333,7 +284,7 @@ type EmbeddingRequest struct {
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
}
// EmbeddingResponse is the response from [Client.Embeddings].
@@ -379,7 +330,7 @@ type ShowRequest struct {
Template string `json:"template"`
Verbose bool `json:"verbose"`
Options map[string]any `json:"options"`
Options map[string]interface{} `json:"options"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`
@@ -387,18 +338,16 @@ type ShowRequest struct {
// ShowResponse is the response returned from [Client.Show].
type ShowResponse struct {
License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"`
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"`
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
}
// CopyRequest is the request passed to [Client.Copy].
@@ -410,9 +359,9 @@ type CopyRequest struct {
// PullRequest is the request passed to [Client.Pull].
type PullRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored
Username string `json:"username"` // Deprecated: ignored
Password string `json:"password"` // Deprecated: ignored
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Stream *bool `json:"stream,omitempty"`
// Deprecated: set the model name with Model instead
@@ -516,13 +465,6 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"`
}
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`
Type string `json:"type"`
Shape []uint64 `json:"shape"`
}
func (m *Metrics) Summary() {
if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
@@ -551,7 +493,7 @@ func (m *Metrics) Summary() {
}
}
func (opts *Options) FromMap(m map[string]any) error {
func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
@@ -608,12 +550,12 @@ func (opts *Options) FromMap(m map[string]any) error {
}
field.SetString(val)
case reflect.Slice:
// JSON unmarshals to []any, not []string
val, ok := val.([]any)
// JSON unmarshals to []interface{}, not []string
val, ok := val.([]interface{})
if !ok {
return fmt.Errorf("option %q must be of type array", key)
}
// convert []any to []string
// convert []interface{} to []string
slice := make([]string, len(val))
for i, item := range val {
str, ok := item.(string)
@@ -667,7 +609,7 @@ func DefaultOptions() Options {
Runner: Runner{
// options set when the model is loaded
NumCtx: int(envconfig.ContextLength()),
NumCtx: 2048,
NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumThread: 0, // let the runtime decide
@@ -720,7 +662,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
}
// FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]any, error) {
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
@@ -734,7 +676,7 @@ func FormatParams(params map[string][]string) (map[string]any, error) {
}
}
out := make(map[string]any)
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; !ok {

View File

@@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var oMap map[string]any
var oMap map[string]interface{}
err := json.Unmarshal([]byte(test.req), &oMap)
require.NoError(t, err)
opts := DefaultOptions()
@@ -231,144 +231,3 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
}
}
}
func TestToolFunction_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
wantErr string
}{
{
name: "valid enum with same types",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": ["a", "b", "c"]
}
}
}
}`,
wantErr: "",
},
{
name: "empty enum array",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": []
}
}
}
}`,
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var tf ToolFunction
err := json.Unmarshal([]byte(tt.input), &tf)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}
func TestPropertyType_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
expected PropertyType
}{
{
name: "string type",
input: `"string"`,
expected: PropertyType{"string"},
},
{
name: "array of types",
input: `["string", "number"]`,
expected: PropertyType{"string", "number"},
},
{
name: "array with single type",
input: `["string"]`,
expected: PropertyType{"string"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var pt PropertyType
if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(pt) != len(test.expected) {
t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
}
for i, v := range pt {
if v != test.expected[i] {
t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
}
}
})
}
}
func TestPropertyType_MarshalJSON(t *testing.T) {
tests := []struct {
name string
input PropertyType
expected string
}{
{
name: "single type",
input: PropertyType{"string"},
expected: `"string"`,
},
{
name: "multiple types",
input: PropertyType{"string", "number"},
expected: `["string","number"]`,
},
{
name: "empty type",
input: PropertyType{},
expected: `[]`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
data, err := json.Marshal(test.input)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if string(data) != test.expected {
t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
}
})
}
}

View File

@@ -17,6 +17,6 @@ If you want to build the installer, youll need to install
In the top directory of this repo, run the following powershell script
to build the ollama CLI, ollama app, and ollama installer.
```powershell
```
powershell -ExecutionPolicy Bypass -File .\scripts\build_windows.ps1
```

View File

@@ -1,178 +0,0 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

384
cache/cache.go vendored Normal file
View File

@@ -0,0 +1,384 @@
package cache
import (
"errors"
"fmt"
"log/slog"
"math"
"slices"
"github.com/ollama/ollama/ml"
)
var ErrNotSupported = errors.New("model does not support operation")
type Cache interface {
// used by model implementations
Sub(i int) Cache
Put(ctx ml.Context, key, value ml.Tensor) (ml.Tensor, ml.Tensor, ml.Tensor)
// cache management
Close()
StartForward(ctx ml.Context, seqs []int) error
CopyPrefix(srcSeq, dstSeq int, len int)
Remove(seq int, beginIndex, endIndex int) error
}
type Causal struct {
DType ml.DType
Capacity int
// current forward pass
curLayer int
curPos int
curBatchSize int
curMask ml.Tensor
curCellRange cellRange
// metadata
cells []cacheCell
seqNextPos map[int]int
cellRanges map[int]cellRange
// cache data storage
backend ml.Backend
cacheCtx ml.Context
keys, values []ml.Tensor
}
type seqCell struct {
seq int
pos int
}
type cacheCell struct {
sequences []seqCell
}
type cellRange struct {
min int
max int
}
func (cell cacheCell) findSeq(seq int) *seqCell {
for i := range cell.sequences {
if cell.sequences[i].seq == seq {
return &cell.sequences[i]
}
}
return nil
}
func NewCausalCache(backend ml.Backend, capacity int, dtype ml.DType) Cache {
return &Causal{
Capacity: capacity,
DType: dtype,
cells: make([]cacheCell, capacity),
seqNextPos: make(map[int]int),
cellRanges: make(map[int]cellRange),
backend: backend,
// TODO(jessegross): This context is not sized appropriately
cacheCtx: backend.NewContext(),
}
}
func (c *Causal) Close() {
c.cacheCtx.Close()
}
var ErrKvCacheFull = errors.New("could not find a kv cache slot")
func (c *Causal) StartForward(ctx ml.Context, seqs []int) error {
c.curBatchSize = len(seqs)
var err error
c.curPos, err = c.findStartPos()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curPos, err = c.findStartPos()
}
if err != nil {
return err
}
// TODO(jessegross): There should be a better way to do this
origSeq := make(map[int]int)
for k, v := range c.seqNextPos {
origSeq[k] = v
}
c.curCellRange = newRange()
for i, seq := range seqs {
c.cells[c.curPos+i] = cacheCell{sequences: []seqCell{{seq: seq, pos: c.seqNextPos[seq]}}}
c.seqNextPos[seq]++
ranges := c.cellRanges[seq]
if c.curPos+i > ranges.max {
ranges.max = c.curPos + i
}
if ranges.max > c.curCellRange.max {
c.curCellRange.max = ranges.max
}
if c.curPos+i < ranges.min {
ranges.min = c.curPos + i
}
if ranges.min < c.curCellRange.min {
c.curCellRange.min = ranges.min
}
c.cellRanges[seq] = ranges
}
c.curMask, err = c.buildMask(ctx, origSeq, seqs)
return err
}
func newRange() cellRange {
return cellRange{
min: math.MaxInt,
max: 0,
}
}
func (c *Causal) findStartPos() (int, error) {
var start, count int
for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
count++
if count >= c.curBatchSize {
return start, nil
}
} else {
start = i + 1
count = 0
}
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
func (c *Causal) buildMask(ctx ml.Context, origSeq map[int]int, seqs []int) (ml.Tensor, error) {
// TODO(jessegross): This makes a number of simplifications such as no padding
len := c.curCellRange.max - c.curCellRange.min
mask := make([]float32, c.curBatchSize*len)
for i := range c.curBatchSize {
for j := c.curCellRange.min; j < c.curCellRange.max; j++ {
cellSeq := c.cells[j].findSeq(seqs[i])
if cellSeq == nil || cellSeq.pos > origSeq[seqs[i]]+i {
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
}
}
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
}
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
for _, obj := range objs {
srcView := obj.View(ctx, int(obj.Stride(2))*src, int(obj.Dim(0)*obj.Dim(1))*len)
dstView := obj.View(ctx, int(obj.Stride(2))*dst, int(obj.Dim(0)*obj.Dim(1))*len)
ctx.Forward(srcView.Copy(ctx, dstView))
}
}
func (c *Causal) defrag() {
slog.Debug("defragmenting kv cache")
// Defrag strategy:
// - Search for empty holes at the beginning of the cache,
// filling them with active data starting at the end
// - If there are contiguous elements that need to be moved,
// combine them into a single operation by holding new moves
// until we see the next one is non-contiguous
// - Fill up the context with the maximum number of operations it
// can hold then compute that and continue with a new context
// TODO(jessegross):
// - Need to size the context and compute maxMoves correctly
// - Just compacts, doesn't optimize placement
maxMoves := 8192 / (6 * len(c.keys))
ctx := c.backend.NewContext()
moves := 0
var pendingSrc, pendingDst, pendingLen int
for dst := range c.cells {
if len(c.cells[dst].sequences) == 0 {
for src := len(c.cells) - 1; src > dst; src-- {
if len(c.cells[src].sequences) != 0 {
c.cells[dst] = c.cells[src]
c.cells[src] = cacheCell{}
if pendingLen > 0 {
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
pendingSrc = src
pendingLen++
break
} else {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
moves++
}
}
pendingSrc = src
pendingDst = dst
pendingLen = 1
break
}
}
}
if moves >= maxMoves {
ctx.Compute(nil)
ctx.Close()
ctx = c.backend.NewContext()
moves = 0
}
}
if pendingLen > 0 {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
moves++
}
if moves > 0 {
ctx.Compute(nil)
}
ctx.Close()
for seq := range c.cellRanges {
seqRange := newRange()
for i, cell := range c.cells {
if cell.findSeq(seq) != nil {
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[seq] = seqRange
}
}
func (c *Causal) Sub(i int) Cache {
if i >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
}
c.curLayer = i
return c
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) (ml.Tensor, ml.Tensor, ml.Tensor) {
if c.curBatchSize != int(key.Dim(2)) {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, int(key.Dim(2))))
}
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int64(c.Capacity))
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int64(c.Capacity))
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, int(key.Stride(2))*c.curPos, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, int(value.Stride(2))*c.curPos, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
len := c.curCellRange.max - c.curCellRange.min
key = c.keys[c.curLayer].View(ctx, int(key.Stride(2))*c.curCellRange.min,
int(key.Dim(0)), int(key.Stride(1)),
int(key.Dim(1)), int(key.Stride(2)),
len,
)
value = c.values[c.curLayer].View(ctx, int(key.Stride(2))*c.curCellRange.min,
int(value.Dim(0)), int(value.Stride(1)),
int(value.Dim(1)), int(value.Stride(2)),
len,
)
return key, value, c.curMask
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int) {
seqRange := newRange()
for i := range c.cells {
srcCellSeq := c.cells[i].findSeq(srcSeq)
dstCellSeq := c.cells[i].findSeq(dstSeq)
if dstCellSeq != nil {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == dstSeq })
}
if srcCellSeq != nil && srcCellSeq.pos < len {
c.cells[i].sequences = append(c.cells[i].sequences, seqCell{seq: dstSeq, pos: srcCellSeq.pos})
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[dstSeq] = seqRange
c.seqNextPos[dstSeq] = len
}
func (c *Causal) shift(seq int, beginIndex, endIndex, offset int) error {
panic("Shift not yet implemented")
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int) error {
endIndex = min(endIndex, c.seqNextPos[seq])
offset := beginIndex - endIndex
seqRange := newRange()
for i := range c.cells {
cellSeq := c.cells[i].findSeq(seq)
if cellSeq != nil {
if cellSeq.pos >= beginIndex && cellSeq.pos < endIndex {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == seq })
} else {
if cellSeq.pos >= endIndex {
cellSeq.pos += offset
}
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
}
if endIndex != c.seqNextPos[seq] {
err := c.shift(seq, endIndex, c.seqNextPos[seq], offset)
if err != nil {
return err
}
}
c.cellRanges[seq] = seqRange
c.seqNextPos[seq] += offset
return nil
}

48
cache/tensor.go vendored Normal file
View File

@@ -0,0 +1,48 @@
package cache
import (
"github.com/ollama/ollama/ml"
)
type TensorCache struct {
curLayer int
cacheCtx ml.Context
keys, values []ml.Tensor
}
func NewTensorCache(backend ml.Backend) *TensorCache {
return &TensorCache{
// TODO(jessegross): This context is not sized appropriately
cacheCtx: backend.NewContext(),
}
}
func (c *TensorCache) Close() {
c.cacheCtx.Close()
}
func (c *TensorCache) Sub(i int) *TensorCache {
if i >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
}
c.curLayer = i
return c
}
func (c *TensorCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.keys[c.curLayer], c.values[c.curLayer], nil
}
func (c *TensorCache) Put(ctx ml.Context, key, value ml.Tensor) {
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
}

View File

@@ -18,8 +18,6 @@ import (
"os/signal"
"path/filepath"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"sync/atomic"
@@ -36,6 +34,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/runner"
@@ -257,7 +256,6 @@ func StopHandler(cmd *cobra.Command, args []string) error {
if strings.Contains(err.Error(), "not found") {
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
}
return err
}
return nil
}
@@ -268,7 +266,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts := runOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
Options: map[string]interface{}{},
}
format, err := cmd.Flags().GetString("format")
@@ -340,21 +338,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
// TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers
// that don't have the capabilities field in the model info
if len(info.ProjectorInfo) != 0 {
opts.MultiModal = true
}
for k := range info.ModelInfo {
if strings.Contains(k, ".vision.") {
opts.MultiModal = true
break
}
}
opts.MultiModal = true //len(info.ProjectorInfo) != 0
opts.ParentModel = info.Details.ParentModel
if interactive {
@@ -575,9 +559,8 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
parameters, errParams := cmd.Flags().GetBool("parameters")
system, errSystem := cmd.Flags().GetBool("system")
template, errTemplate := cmd.Flags().GetBool("template")
verbose, errVerbose := cmd.Flags().GetBool("verbose")
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
if boolErr != nil {
return errors.New("error retrieving flags")
}
@@ -615,7 +598,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
}
req := api.ShowRequest{Name: args[0], Verbose: verbose}
req := api.ShowRequest{Name: args[0]}
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return err
@@ -638,10 +621,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return nil
}
return showInfo(resp, verbose, os.Stdout)
return showInfo(resp, os.Stdout)
}
func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
func showInfo(resp *api.ShowResponse, w io.Writer) error {
tableRender := func(header string, rows func() [][]string) {
fmt.Fprintln(w, " ", header)
table := tablewriter.NewWriter(w)
@@ -675,15 +658,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
return
})
if len(resp.Capabilities) > 0 {
tableRender("Capabilities", func() (rows [][]string) {
for _, capability := range resp.Capabilities {
rows = append(rows, []string{"", capability.String()})
}
return
})
}
if resp.ProjectorInfo != nil {
tableRender("Projector", func() (rows [][]string) {
arch := resp.ProjectorInfo["general.architecture"].(string)
@@ -707,47 +681,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
})
}
if resp.ModelInfo != nil && verbose {
tableRender("Metadata", func() (rows [][]string) {
keys := make([]string, 0, len(resp.ModelInfo))
for k := range resp.ModelInfo {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
var v string
switch vData := resp.ModelInfo[k].(type) {
case bool:
v = fmt.Sprintf("%t", vData)
case string:
v = vData
case float64:
v = fmt.Sprintf("%g", vData)
case []any:
n := 3
if len(vData) < n {
n = len(vData)
}
v = fmt.Sprintf("%v", vData[:n])
default:
v = fmt.Sprintf("%T", vData)
}
rows = append(rows, []string{"", k, v})
}
return
})
}
if len(resp.Tensors) > 0 && verbose {
tableRender("Tensors", func() (rows [][]string) {
for _, t := range resp.Tensors {
rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
}
return
})
}
head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s))
for scanner.Scan() && (len(rows) < n || n < 0) {
@@ -808,38 +741,13 @@ func PullHandler(cmd *cobra.Command, args []string) error {
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
// This is the initial status update for the
// layer, which the server sends before
// beginning the download, for clients to
// compute total size and prepare for
// downloads, if needed.
//
// Skipping this here to avoid showing a 0%
// progress bar, which *should* clue the user
// into the fact that many things are being
// downloaded and that the current active
// download is not that last. However, in rare
// cases it seems to be triggering to some, and
// it isn't worth explaining, so just ignore
// and regress to the old UI that keeps giving
// you the "But wait, there is more!" after
// each "100% done" bar, which is "better."
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@@ -859,7 +767,11 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
return client.Pull(cmd.Context(), &request, fn)
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
}
return nil
}
type generateContextKey string
@@ -873,7 +785,7 @@ type runOptions struct {
Format string
System string
Images []api.ImageData
Options map[string]any
Options map[string]interface{}
MultiModal bool
KeepAlive *api.Duration
}
@@ -1275,7 +1187,6 @@ func NewCLI() *cobra.Command {
showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
showCmd.Flags().Bool("template", false, "Show template of a model")
showCmd.Flags().Bool("system", false, "Show system message of a model")
showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")
runCmd := &cobra.Command{
Use: "run MODEL [PROMPT]",
@@ -1360,6 +1271,7 @@ func NewCLI() *cobra.Command {
runnerCmd := &cobra.Command{
Use: "runner",
Short: llama.PrintSystemInfo(),
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
return runner.Execute(os.Args[1:])
@@ -1402,12 +1314,12 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_SCHED_SPREAD"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_KV_CACHE_TYPE"],
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_GPU_OVERHEAD"],
envVars["OLLAMA_LOAD_TIMEOUT"],
envVars["OLLAMA_CONTEXT_LENGTH"],
})
default:
appendEnvDocs(cmd, envs)

View File

@@ -10,13 +10,11 @@ import (
"os"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func TestShowInfo(t *testing.T) {
@@ -28,7 +26,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -58,7 +56,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -69,60 +67,6 @@ func TestShowInfo(t *testing.T) {
embedding length 0
quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("verbose model", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "8B",
QuantizationLevel: "FP16",
},
Parameters: `
stop up`,
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(8_000_000_000),
"some.true_bool": true,
"some.false_bool": false,
"test.context_length": float64(1000),
"test.embedding_length": float64(11434),
},
Tensors: []api.Tensor{
{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
},
}, true, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 8B
context length 1000
embedding length 11434
quantization FP16
Parameters
stop up
Metadata
general.architecture test
general.parameter_count 8e+09
some.false_bool false
some.true_bool true
test.context_length 1000
test.embedding_length 11434
Tensors
blk.0.attn_k.weight BF16 [42 3117]
blk.0.attn_q.weight FP16 [3117 42]
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
@@ -144,7 +88,7 @@ func TestShowInfo(t *testing.T) {
stop you
stop up
temperature 99`,
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -181,7 +125,7 @@ func TestShowInfo(t *testing.T) {
"clip.vision.embedding_length": float64(0),
"clip.vision.projection_dim": float64(0),
},
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -214,7 +158,7 @@ func TestShowInfo(t *testing.T) {
Ahoy, matey!
Weigh anchor!
`,
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -243,7 +187,7 @@ Weigh anchor!
QuantizationLevel: "FP16",
},
License: license,
}, false, &b); err != nil {
}, &b); err != nil {
t.Fatal(err)
}
@@ -261,34 +205,6 @@ Weigh anchor!
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("capabilities", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
}, false, &b); err != nil {
t.Fatal(err)
}
expect := " Model\n" +
" architecture test \n" +
" parameters 7B \n" +
" quantization FP16 \n" +
"\n" +
" Capabilities\n" +
" vision \n" +
" tools \n" +
"\n"
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
}
func TestDeleteHandler(t *testing.T) {
@@ -415,7 +331,6 @@ func TestGetModelfileName(t *testing.T) {
if err != nil {
t.Fatalf("temp modelfile creation failed: %v", err)
}
defer tempFile.Close()
expectedFilename = tempFile.Name()
err = cmd.Flags().Set("file", expectedFilename)
@@ -575,96 +490,6 @@ func TestPushHandler(t *testing.T) {
}
}
func TestListHandler(t *testing.T) {
tests := []struct {
name string
args []string
serverResponse []api.ListModelResponse
expectedError string
expectedOutput string
}{
{
name: "list all models",
args: []string{},
serverResponse: []api.ListModelResponse{
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)},
},
expectedOutput: "NAME ID SIZE MODIFIED \n" +
"model1 sha256:abc12 1.0 KB 24 hours ago \n" +
"model2 sha256:def45 2.0 KB 2 days ago \n",
},
{
name: "filter models by prefix",
args: []string{"model1"},
serverResponse: []api.ListModelResponse{
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)},
},
expectedOutput: "NAME ID SIZE MODIFIED \n" +
"model1 sha256:abc12 1.0 KB 24 hours ago \n",
},
{
name: "server error",
args: []string{},
expectedError: "server error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/tags" || r.Method != http.MethodGet {
t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
http.Error(w, "not found", http.StatusNotFound)
return
}
if tt.expectedError != "" {
http.Error(w, tt.expectedError, http.StatusInternalServerError)
return
}
response := api.ListResponse{Models: tt.serverResponse}
if err := json.NewEncoder(w).Encode(response); err != nil {
t.Fatal(err)
}
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.SetContext(context.TODO())
// Capture stdout
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
err := ListHandler(cmd, tt.args)
// Restore stdout and get output
w.Close()
os.Stdout = oldStdout
output, _ := io.ReadAll(r)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if got := string(output); got != tt.expectedOutput {
t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
}
})
}
}
func TestCreateHandler(t *testing.T) {
tests := []struct {
name string
@@ -791,132 +616,3 @@ func TestCreateHandler(t *testing.T) {
})
}
}
func TestNewCreateRequest(t *testing.T) {
tests := []struct {
name string
from string
opts runOptions
expected *api.CreateRequest
}{
{
"basic test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "",
Prompt: "You are a fun AI agent",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
},
},
{
"parent model as filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "/some/file/like/etc/passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model as windows filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"options test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Options: map[string]any{
"temperature": 1.0,
},
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
Parameters: map[string]any{
"temperature": 1.0,
},
},
},
{
"messages test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := NewCreateRequest(tt.from, tt.opts)
if !cmp.Equal(actual, tt.expected) {
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
}
})
}
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
)
type MultilineState int
@@ -196,10 +195,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err)
continue
}
return err
}
continue
@@ -348,7 +343,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch args[1] {
case "info":
_ = showInfo(resp, false, os.Stderr)
_ = showInfo(resp, os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
@@ -460,16 +455,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel := opts.ParentModel
modelName := model.ParseName(parentModel)
if !modelName.IsValid() {
parentModel = ""
}
req := &api.CreateRequest{
Model: name,
From: cmp.Or(parentModel, opts.Model),
Name: name,
From: cmp.Or(opts.ParentModel, opts.Model),
}
if opts.System != "" {
@@ -503,7 +491,6 @@ func normalizeFilePath(fp string) string {
"\\\\", "\\", // Escaped backslash
"\\*", "*", // Escaped asterisk
"\\?", "?", // Escaped question mark
"\\~", "~", // Escaped tilde
).Replace(fp)
}

View File

@@ -7,20 +7,14 @@ import (
"io"
"io/fs"
"log/slog"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
TextModel TextParameters `json:"text_config"`
}
type TextParameters struct {
VocabSize uint32 `json:"vocab_size"`
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
}
type AdapterParameters struct {
@@ -85,6 +79,14 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
func (ModelParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(ws, kv, ts)
}
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(ws, kv, ts)
}
type ModelConverter interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
@@ -96,6 +98,8 @@ type ModelConverter interface {
// specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
}
type moreParser interface {
@@ -110,6 +114,8 @@ type AdapterConverter interface {
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
}
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
@@ -147,7 +153,7 @@ func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
return err
}
return writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
@@ -171,20 +177,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
var conv ModelConverter
switch p.Architectures[0] {
case "LlamaForCausalLM":
case "LlamaForCausalLM", "MistralForCausalLM":
conv = &llamaModel{}
case "Llama4ForConditionalGeneration":
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":
conv = &gemmaModel{}
case "Gemma2ForCausalLM":
conv = &gemma2Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "Qwen2ForCausalLM":
@@ -194,7 +194,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
case "CohereForCausalLM":
conv = &commandrModel{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return errors.New("unsupported architecture")
}
if err := json.Unmarshal(bts, conv); err != nil {
@@ -213,14 +213,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
}
vocabSize := int(p.VocabSize)
if vocabSize == 0 {
tVocabSize := int(p.TextModel.VocabSize)
vocabSize = tVocabSize
}
switch {
case vocabSize == 0:
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
case vocabSize > len(t.Vocabulary.Tokens):
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) {
@@ -239,13 +232,5 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
return err
}
return writeFile(ws, conv.KV(t), conv.Tensors(ts))
}
func writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)
}
return ggml.WriteGGUF(ws, kv, ts)
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
}

View File

@@ -45,7 +45,7 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
if strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne)
}

View File

@@ -1,142 +0,0 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
gemmaModel
Architecture string
TextModel struct {
HeadDim uint32 `json:"head_dim"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
SlidingWindow uint32 `json:"sliding_window"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"` // attention.head_count 16
LayerNormEpsilon float32 `json:"layer_norm_eps"` // attention.layer_norm_epsilon 1e-05
NumHiddenLayers uint32 `json:"num_hidden_layers"` // block_count 32
HiddenSize uint32 `json:"hidden_size"` // embedding_length 1280
IntermediateSize uint32 `json:"intermediate_size"` // feed_forward_length 5120
ImageSize uint32 `json:"image_size"` // image_size 560
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
}
const (
gemma4BLayerCount = 34
gemma12BLayerCount = 48
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"
numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
kv["gemma3.block_count"] = numBlocks
var (
numHeads uint32
numKVHeads uint32
)
switch numBlocks {
case gemma4BLayerCount:
numHeads = 8
numKVHeads = 4
case gemma12BLayerCount:
numHeads = 16
numKVHeads = 8
case gemma27BLayerCount:
numHeads = 32
numKVHeads = 16
default:
numHeads = p.NumAttentionHeads
numKVHeads = p.NumKeyValueHeads
}
kv["gemma3.attention.head_count"] = numHeads
kv["gemma3.attention.head_count_kv"] = numKVHeads
switch p.Architecture {
case "Gemma3ForCausalLM":
kv["gemma3.context_length"] = p.MaxPositionEmbeddings
kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
}
if p.MultiModalTokensPerImage > 0 {
kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
}
return kv
}
func (p *gemma3Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"vision_tower.vision_model.embeddings", "v",
"vision_tower.vision_model", "v",
"vision_model.vision_model.embeddings", "v",
"vision_model.vision_model", "v",
"language_model.", "",
"model.layers", "blk",
"encoder.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.out_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"input_projection_weight", "input_projection.weight",
"multi_modal_projector", "mm",
}
}

View File

@@ -28,12 +28,12 @@ type llamaModel struct {
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling struct {
Type string `json:"type"`
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
Type string `json:"type"`
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
factors ropeFactor
} `json:"rope_scaling"`
@@ -42,8 +42,6 @@ type llamaModel struct {
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
HeadDim uint32 `json:"head_dim"`
skipRepack bool
}
var _ ModelConverter = (*llamaModel)(nil)
@@ -72,10 +70,6 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
}
if p.HeadDim > 0 {
kv["llama.attention.head_dim"] = p.HeadDim
}
if p.RopeTheta > 0 {
kv["llama.rope.freq_base"] = p.RopeTheta
}
@@ -90,7 +84,7 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionEmbeddings, 8192)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
lambdaLow := float32(original) / factorLow
lambdaHigh := float32(original) / factorHigh
@@ -139,10 +133,9 @@ func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
}
for _, t := range ts {
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
if !p.skipRepack {
t.SetRepacker(p.repack)
}
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(p.repack)
}
out = append(out, ggml.Tensor{

View File

@@ -1,169 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type llama4Model struct {
ModelParameters
TextModel struct {
llamaModel
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NumLocalExperts uint32 `json:"num_local_experts"`
InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
UseQKNorm bool `json:"use_qk_norm"`
IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"`
AttentionChunkSize uint32 `json:"attention_chunk_size"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
NormEpsilon float32 `json:"norm_eps"`
PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
} `json:"vision_config"`
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"
for k, v := range p.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
}
}
kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize
kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
return kv
}
// Replacements implements ModelConverter.
func (p *llama4Model) Replacements() []string {
return append(
p.TextModel.Replacements(),
"language_model.", "",
"vision_model", "v",
"multi_modal_projector", "mm",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.", "ffn_",
"shared_expert.down_proj", "down_shexp",
"shared_expert.gate_proj", "gate_shexp",
"shared_expert.up_proj", "up_shexp",
"experts.down_proj", "down_exps.weight",
"experts.gate_up_proj", "gate_up_exps.weight",
"router", "gate_inp",
"patch_embedding.linear", "patch_embedding",
)
}
// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
var textTensors []Tensor
for _, t := range ts {
if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
// gate and up projectors are fused
// dims[1], dims[2] must be swapped
// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
halfDim := int(t.Shape()[2]) / 2
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2]/2, newShape[1]
for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
// clone tensor since we need separate repackers
tt := t.Clone()
tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
out = append(out, ggml.Tensor{
Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
Kind: tt.Kind(),
Shape: newShape,
WriterTo: tt,
})
}
} else if strings.Contains(t.Name(), "ffn_down_exps") {
// dims[1], dims[2] must be swapped
// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
t.SetRepacker(p.repack())
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2], newShape[1]
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
} else {
textTensors = append(textTensors, t)
}
}
p.TextModel.skipRepack = true
out = append(out, p.TextModel.Tensors(textTensors)...)
return out
}
func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
return func(name string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
if err := t.T(0, 2, 1); err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be return as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@@ -1,190 +0,0 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3Model struct {
ModelParameters
ImageTokenIndex uint32 `json:"image_token_index"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
VisionFeatureLayer int32 `json:"vision_feature_layer"`
TextModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
ImageSize uint32 `json:"image_size"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"`
} `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
// Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex
kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
if p.ProjectorHiddenAct != "" {
kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
}
return kv
}
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
t.SetRepacker(p.repack)
}
}
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3Model) Replacements() []string {
return []string{
"language_model.model.norm", "output_norm",
"language_model.model.", "",
"language_model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, ".attn_q.weight") {
heads = p.TextModel.NumAttentionHeads
} else if strings.HasSuffix(name, ".attn_k.weight") {
heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -118,5 +118,6 @@ func (p *phi3Model) Replacements() []string {
type ropeFactor []float32
func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
return 0, binary.Write(w, binary.LittleEndian, r)
err := binary.Write(w, binary.LittleEndian, r)
return 0, err
}

View File

@@ -2,6 +2,7 @@ package convert
import "github.com/ollama/ollama/fs/ggml"
type qwen2Model struct {
ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`

View File

@@ -11,6 +11,7 @@ import (
"io"
"io/fs"
"log/slog"
"math"
"os"
"path/filepath"
"slices"
@@ -47,7 +48,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
}
t.Cleanup(func() { r.Close() })
m, _, err := ggml.Decode(r, -1)
m, _, err := ggml.Decode(r, math.MaxInt)
if err != nil {
t.Fatal(err)
}
@@ -331,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
}
defer r.Close()
m, _, err := ggml.Decode(r, -1)
m, _, err := ggml.Decode(r, math.MaxInt)
if err != nil {
t.Fatal(err)
}

View File

@@ -11,15 +11,14 @@ type Tensor interface {
Name() string
Shape() []uint64
Kind() uint32
SetRepacker(Repacker)
SetRepacker(repacker)
WriteTo(io.Writer) (int64, error)
Clone() Tensor
}
type tensorBase struct {
name string
shape []uint64
repacker Repacker
name string
shape []uint64
repacker
}
func (t tensorBase) Name() string {
@@ -37,8 +36,7 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" {
t.name == "token_types.weight" {
// these tensors are always F32
return 0
}
@@ -53,18 +51,21 @@ func (t tensorBase) Kind() uint32 {
}
}
func (t *tensorBase) SetRepacker(fn Repacker) {
func (t *tensorBase) SetRepacker(fn repacker) {
t.repacker = fn
}
type Repacker func(string, []float32, []uint64) ([]float32, error)
type repacker func(string, []float32, []uint64) ([]float32, error)
func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
patterns := []struct {
Pattern string
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
}{
{"*.safetensors", parseSafetensors},
{"model-*-of-*.safetensors", parseSafetensors},
{"model.safetensors", parseSafetensors},
{"adapters.safetensors", parseSafetensors},
{"adapter_model.safetensors", parseSafetensors},
{"pytorch_model-*-of-*.bin", parseTorch},
{"pytorch_model.bin", parseTorch},
{"consolidated.*.pth", parseTorch},

View File

@@ -94,21 +94,6 @@ type safetensor struct {
*tensorBase
}
func (st safetensor) Clone() Tensor {
return &safetensor{
fs: st.fs,
path: st.path,
dtype: st.dtype,
offset: st.offset,
size: st.size,
tensorBase: &tensorBase{
name: st.name,
repacker: st.repacker,
shape: slices.Clone(st.shape),
},
}
}
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
f, err := st.fs.Open(st.path)
if err != nil {

View File

@@ -43,17 +43,6 @@ type torch struct {
*tensorBase
}
func (t torch) Clone() Tensor {
return torch{
storage: t.storage,
tensorBase: &tensorBase{
name: t.name,
shape: t.shape,
repacker: t.repacker,
},
}
}
func (pt torch) WriteTo(w io.Writer) (int64, error) {
return 0, nil
}

View File

@@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {
var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_sentencepiece_model_proto_goTypes = []any{
var file_sentencepiece_model_proto_goTypes = []interface{}{
(TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType
(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
(*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec
@@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
return
}
if !protoimpl.UnsafeEnabled {
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*TrainerSpec); i {
case 0:
return &v.state
@@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*NormalizerSpec); i {
case 0:
return &v.state
@@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelfTestData); i {
case 0:
return &v.state
@@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ModelProto); i {
case 0:
return &v.state
@@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SelfTestData_Sample); i {
case 0:
return &v.state
@@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
return nil
}
}
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any {
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ModelProto_SentencePiece); i {
case 0:
return &v.state

View File

@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"io/fs"
"log/slog"
"os"
"reflect"
"slices"
"google.golang.org/protobuf/proto"
@@ -17,8 +15,6 @@ import (
)
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
slog.Debug("using spm vocabulary")
ast, err := parseAdditionalSpecialTokens(fsys)
if err != nil {
return nil, err
@@ -47,19 +43,10 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
v.Types = append(v.Types, int32(t))
default:
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// temporary fix to handle gemma3 broken configs
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
if slices.Contains(ast, piece.GetPiece()) {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
}
for _, t := range ast {
if t.Content == piece.GetPiece() {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
break
}
}
v.Types = append(v.Types, tt)
}
}
@@ -91,16 +78,10 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
return cmp.Compare(i.id, j.id)
})
for _, t := range ts {
if t.id < len(v.Tokens) {
if v.Tokens[t.id] == t.content {
slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
continue
}
return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
}
if t.id != len(v.Tokens) {
return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
n := len(v.Tokens)
for i, t := range ts {
if t.id != i+n {
return nil, fmt.Errorf("invalid token id: %d", t.id)
}
v.Tokens = append(v.Tokens, t.content)
@@ -111,15 +92,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
return &v, nil
}
type specialToken struct {
Content string `json:"content"`
Lstrip bool `json:"lstrip"`
Normalized bool `json:"normalized"`
Rstrip bool `json:"rstrip"`
SingleWord bool `json:"single_word"`
}
func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
f, err := fsys.Open("special_tokens_map.json")
if errors.Is(err, os.ErrNotExist) {
return nil, nil
@@ -129,43 +102,12 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
defer f.Close()
var m struct {
AdditionalSpecialTokens any `json:"additional_special_tokens"`
AdditionalSpecialTokens []string `json:"additional_special_tokens"`
}
if err := json.NewDecoder(f).Decode(&m); err != nil {
return nil, err
}
var ast []specialToken
switch st := m.AdditionalSpecialTokens.(type) {
case []string:
for _, s := range st {
ast = append(ast, specialToken{Content: s})
}
case []any:
for _, s := range st {
// marshal and unmarshal the object to get the special token
tMap := s.(map[string]any)
data, err := json.Marshal(tMap)
if err != nil {
return nil, err
}
var token specialToken
err = json.Unmarshal(data, &token)
if err != nil {
return nil, err
}
ast = append(ast, token)
}
default:
slog.Warn("special token", "unknown token", reflect.TypeOf(st))
}
slog.Debug("spm tokenizer", "additional tokens", ast)
return ast, nil
return m.AdditionalSpecialTokens, nil
}

View File

@@ -9,6 +9,8 @@ import (
"path/filepath"
"runtime"
"strings"
"github.com/ollama/ollama/envconfig"
)
// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns
@@ -39,10 +41,13 @@ func commonAMDValidateLibDir() (string, error) {
// Favor our bundled version
// Installer payload location if we're running the installed binary
rocmTargetDir := filepath.Join(LibOllamaPath, "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), envconfig.LibRelativeToExe(), "lib", "ollama")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
// Prefer explicit HIP env var

View File

@@ -77,7 +77,8 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
gfxOverride := envconfig.HsaOverrideGfxVersion()
var supported []string
var libDir string
depPaths := LibraryDirs()
libDir := ""
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
@@ -352,8 +353,9 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
})
return nil, err
}
depPaths = append(depPaths, libDir)
}
gpuInfo.DependencyPath = []string{libDir}
gpuInfo.DependencyPath = depPaths
if gfxOverride == "" {
// Only load supported list once

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"slices"
"strconv"
@@ -49,13 +50,14 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
slog.Info(err.Error())
return nil, err
}
depPaths := LibraryDirs()
libDir, err := AMDValidateLibDir()
if err != nil {
err = fmt.Errorf("unable to verify rocm library: %w", err)
slog.Warn(err.Error())
return nil, err
}
depPaths = append(depPaths, libDir)
var supported []string
gfxOverride := envconfig.HsaOverrideGfxVersion()
@@ -111,7 +113,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
UnreliableFreeMemory: true,
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
DependencyPath: []string{libDir},
DependencyPath: depPaths,
MinimumMemory: rocmMinimumMemory,
Name: name,
Compute: gfx,
@@ -162,7 +164,9 @@ func AMDValidateLibDir() (string, error) {
}
// Installer payload (if we're running from some other location)
rocmTargetDir := filepath.Join(LibOllamaPath, "rocm")
localAppData := os.Getenv("LOCALAPPDATA")
appDir := filepath.Join(localAppData, "Programs", "Ollama")
rocmTargetDir := filepath.Join(appDir, envconfig.LibRelativeToExe(), "lib", "ollama")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ollama installed ROCm at " + rocmTargetDir)
return rocmTargetDir, nil

View File

@@ -12,7 +12,7 @@ func IsNUMA() bool {
// numa support in llama.cpp is linux only
return false
}
ids := map[string]any{}
ids := map[string]interface{}{}
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
for _, packageId := range packageIds {
id, err := os.ReadFile(packageId)

View File

@@ -57,8 +57,7 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
}
}
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
if gpuInfo.computeMajor < 6 || gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
return "v11"
}
return "v12"

View File

@@ -23,6 +23,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/runners"
)
type cudaHandles struct {
@@ -100,7 +101,15 @@ func initCudaHandles() *cudaHandles {
// Aligned with driver, we can't carry as payloads
nvcudaMgmtPatterns := NvcudaGlobs
cudartMgmtPatterns = append(cudartMgmtPatterns, filepath.Join(LibOllamaPath, "cuda_v*", CudartMgmtName))
if runtime.GOOS == "windows" {
localAppData := os.Getenv("LOCALAPPDATA")
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", CudartMgmtName)}
}
libDirs := LibraryDirs()
for _, d := range libDirs {
cudartMgmtPatterns = append(cudartMgmtPatterns, filepath.Join(d, CudartMgmtName))
}
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...)
if len(NvmlGlobs) > 0 {
@@ -231,7 +240,7 @@ func GetGPUInfo() GpuInfoList {
if err != nil {
slog.Warn("error looking up system memory", "error", err)
}
depPaths := LibraryDirs()
details, err := GetCPUDetails()
if err != nil {
slog.Warn("failed to lookup CPU details", "error", err)
@@ -239,9 +248,11 @@ func GetGPUInfo() GpuInfoList {
cpus = []CPUInfo{
{
GpuInfo: GpuInfo{
memInfo: mem,
Library: "cpu",
ID: "0",
memInfo: mem,
Library: "cpu",
Variant: runners.GetCPUCapability().String(),
ID: "0",
DependencyPath: depPaths,
},
CPUs: details,
},
@@ -283,13 +294,17 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor
variant := cudaVariant(gpuInfo)
// Start with our bundled libraries
if variant != "" {
variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
if _, err := os.Stat(variantPath); err == nil {
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
if depPaths != nil {
gpuInfo.DependencyPath = depPaths
// Check for variant specific directory
if variant != "" {
for _, d := range depPaths {
if _, err := os.Stat(filepath.Join(d, "cuda_"+variant)); err == nil {
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
gpuInfo.DependencyPath = append([]string{filepath.Join(d, "cuda_"+variant)}, gpuInfo.DependencyPath...)
break
}
}
}
}
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
@@ -361,7 +376,7 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.DependencyPath = []string{LibOllamaPath}
gpuInfo.DependencyPath = depPaths
oneapiGPUs = append(oneapiGPUs, gpuInfo)
}
}
@@ -497,30 +512,33 @@ func GetGPUInfo() GpuInfoList {
func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
var ldPaths []string
gpuLibPaths := []string{}
slog.Debug("Searching for GPU library", "name", baseLibName)
// search our bundled libraries first
patterns := []string{filepath.Join(LibOllamaPath, baseLibName)}
var ldPaths []string
switch runtime.GOOS {
case "windows":
ldPaths = strings.Split(os.Getenv("PATH"), string(os.PathListSeparator))
case "linux":
ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), string(os.PathListSeparator))
// Start with our bundled libraries
patterns := []string{}
for _, d := range LibraryDirs() {
patterns = append(patterns, filepath.Join(d, baseLibName))
}
// then search the system's LD_LIBRARY_PATH
for _, p := range ldPaths {
p, err := filepath.Abs(p)
switch runtime.GOOS {
case "windows":
ldPaths = strings.Split(os.Getenv("PATH"), ";")
case "linux":
ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
default:
return gpuLibPaths
}
// Then with whatever we find in the PATH/LD_LIBRARY_PATH
for _, ldPath := range ldPaths {
d, err := filepath.Abs(ldPath)
if err != nil {
continue
}
patterns = append(patterns, filepath.Join(p, baseLibName))
patterns = append(patterns, filepath.Join(d, baseLibName))
}
// finally, search the default patterns provided by the caller
patterns = append(patterns, defaultPatterns...)
slog.Debug("gpu library search", "globs", patterns)
for _, pattern := range patterns {
@@ -697,6 +715,23 @@ func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
}
}
func LibraryDirs() []string {
// dependencies can exist wherever we found the runners (e.g. build tree for developers) and relative to the executable
// This can be simplified once we no longer carry runners as payloads
exe, err := os.Executable()
if err != nil {
slog.Warn("failed to lookup executable path", "error", err)
return nil
}
lib := filepath.Join(filepath.Dir(exe), envconfig.LibRelativeToExe(), "lib", "ollama")
if _, err := os.Stat(lib); err != nil {
return nil
}
return []string{lib}
}
func GetSystemInfo() SystemInfo {
gpus := GetGPUInfo()
gpuMutex.Lock()

View File

@@ -15,6 +15,7 @@ import (
"syscall"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/runners"
)
const (
@@ -27,6 +28,7 @@ func GetGPUInfo() GpuInfoList {
return []GpuInfo{
{
Library: "cpu",
Variant: runners.GetCPUCapability().String(),
memInfo: mem,
},
}
@@ -49,6 +51,7 @@ func GetCPUInfo() GpuInfoList {
return []GpuInfo{
{
Library: "cpu",
Variant: runners.GetCPUCapability().String(),
memInfo: mem,
},
}

View File

@@ -111,7 +111,6 @@ func GetCPUDetails() ([]CPU, error) {
if err != nil {
return nil, err
}
defer file.Close()
return linuxCPUDetails(file)
}
@@ -169,11 +168,13 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
for id, s := range socketByID {
s.CoreCount = len(coreBySocket[id])
s.ThreadCount = 0
for _, tc := range threadsByCoreBySocket[id] {
s.ThreadCount += tc
}
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
efficiencyCoreCount := 0
for _, threads := range threadsByCoreBySocket[id] {
s.ThreadCount += threads
if threads == 1 {
efficiencyCoreCount++
}

View File

@@ -1,56 +0,0 @@
package discover
import (
"os"
"path/filepath"
"runtime"
)
// LibPath is a path to lookup dynamic libraries
// in development it's usually 'build/lib/ollama'
// in distribution builds it's 'lib/ollama' on Windows
// '../lib/ollama' on Linux and the executable's directory on macOS
// note: distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as
// 'cuda_v11', 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string {
exe, err := os.Executable()
if err != nil {
return ""
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
var libPath string
switch runtime.GOOS {
case "windows":
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
case "linux":
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
case "darwin":
libPath = filepath.Dir(exe)
}
cwd, err := os.Getwd()
if err != nil {
return ""
}
paths := []string{
libPath,
// build paths for development
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
filepath.Join(cwd, "build", "lib", "ollama"),
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return p
}
}
return filepath.Dir(exe)
}()

View File

@@ -5,6 +5,7 @@ import (
"log/slog"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/runners"
)
type memInfo struct {
@@ -106,7 +107,7 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList {
for _, info := range l {
found := false
requested := info.Library
if info.Variant != "" {
if info.Variant != runners.CPUCapabilityNone.String() {
requested += "_" + info.Variant
}
for i, lib := range libs {

View File

@@ -31,7 +31,7 @@ Certain endpoints stream responses as JSON objects. Streaming can be disabled by
## Generate a completion
```
```shell
POST /api/generate
```
@@ -173,7 +173,7 @@ curl http://localhost:11434/api/generate -d '{
##### Response
```json5
```json
{
"model": "codellama:code",
"created_at": "2024-07-22T20:47:51.147561Z",
@@ -306,7 +306,7 @@ curl http://localhost:11434/api/generate -d '{
#### Response
```json
```
{
"model": "llava",
"created_at": "2023-11-03T15:36:02.583064Z",
@@ -485,7 +485,7 @@ A single JSON object is returned:
## Generate a chat completion
```
```shell
POST /api/chat
```
@@ -495,14 +495,14 @@ Generate the next message in a chat with a provided model. This is a streaming e
- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory
- `tools`: list of tools in JSON for the model to use if supported
- `tools`: tools for the model to use if supported. Requires `stream` to be set to `false`
The `message` object has the following fields:
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
- `content`: the content of the message
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
- `tool_calls` (optional): a list of tools the model wants to use
Advanced parameters (optional):
@@ -558,10 +558,6 @@ Final response:
{
"model": "llama3.2",
"created_at": "2023-08-04T19:22:45.499127Z",
"message": {
"role": "assistant",
"content": ""
},
"done": true,
"total_duration": 4883583458,
"load_duration": 1334875,
@@ -799,7 +795,7 @@ curl http://localhost:11434/api/chat -d '{
##### Request
```shell
```
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
@@ -874,7 +870,7 @@ If the messages array is empty, the model will be loaded into memory.
##### Request
```shell
```
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": []
@@ -882,7 +878,6 @@ curl http://localhost:11434/api/chat -d '{
```
##### Response
```json
{
"model": "llama3.2",
@@ -902,7 +897,7 @@ If the messages array is empty and the `keep_alive` parameter is set to `0`, a m
##### Request
```shell
```
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [],
@@ -929,7 +924,7 @@ A single JSON object is returned:
## Create a Model
```
```shell
POST /api/create
```
@@ -1025,7 +1020,7 @@ curl http://localhost:11434/api/create -d '{
A stream of JSON objects is returned:
```json
```
{"status":"quantizing F16 model to Q4_K_M"}
{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
@@ -1056,7 +1051,7 @@ curl http://localhost:11434/api/create -d '{
A stream of JSON objects is returned:
```json
```
{"status":"parsing GGUF"}
{"status":"using existing layer sha256:432f310a77f4650a88d0fd59ecdd7cebed8d684bafea53cbff0473542964f0c3"}
{"status":"writing manifest"}
@@ -1123,7 +1118,7 @@ Return 200 OK if the blob exists, 404 Not Found if it does not.
## Push a Blob
```
```shell
POST /api/blobs/:digest
```
@@ -1147,7 +1142,7 @@ Return 201 Created if the blob was successfully created, 400 Bad Request if the
## List Local Models
```
```shell
GET /api/tags
```
@@ -1200,7 +1195,7 @@ A single JSON object will be returned.
## Show Model Information
```
```shell
POST /api/show
```
@@ -1217,13 +1212,13 @@ Show information about a model including details, modelfile, template, parameter
```shell
curl http://localhost:11434/api/show -d '{
"model": "llava"
"model": "llama3.2"
}'
```
#### Response
```json5
```json
{
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
"parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"",
@@ -1260,17 +1255,13 @@ curl http://localhost:11434/api/show -d '{
"tokenizer.ggml.pre": "llama-bpe",
"tokenizer.ggml.token_type": [], // populates if `verbose=true`
"tokenizer.ggml.tokens": [] // populates if `verbose=true`
},
"capabilities": [
"completion",
"vision"
],
}
}
```
## Copy a Model
```
```shell
POST /api/copy
```
@@ -1293,7 +1284,7 @@ Returns a 200 OK if successful, or a 404 Not Found if the source model doesn't e
## Delete a Model
```
```shell
DELETE /api/delete
```
@@ -1319,7 +1310,7 @@ Returns a 200 OK if successful, 404 Not Found if the model to be deleted doesn't
## Pull a Model
```
```shell
POST /api/pull
```
@@ -1391,7 +1382,7 @@ if `stream` is set to false, then the response is a single JSON object:
## Push a Model
```
```shell
POST /api/push
```
@@ -1456,7 +1447,7 @@ If `stream` is set to `false`, then the response is a single JSON object:
## Generate Embeddings
```
```shell
POST /api/embed
```
@@ -1524,7 +1515,7 @@ curl http://localhost:11434/api/embed -d '{
```
## List Running Models
```
```shell
GET /api/ps
```
@@ -1571,7 +1562,7 @@ A single JSON object will be returned.
> Note: this endpoint has been superseded by `/api/embed`
```
```shell
POST /api/embeddings
```
@@ -1611,7 +1602,7 @@ curl http://localhost:11434/api/embeddings -d '{
## Version
```
```shell
GET /api/version
```

View File

@@ -1,59 +0,0 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -1,159 +1,165 @@
# Development
Install prerequisites:
Install required tools:
- [Go](https://go.dev/doc/install)
- C/C++ Compiler e.g. Clang on macOS, [TDM-GCC](https://github.com/jmeubank/tdm-gcc/releases/latest) (Windows amd64) or [llvm-mingw](https://github.com/mstorsjo/llvm-mingw) (Windows arm64), GCC/Clang on Linux.
- go version 1.22 or higher
- OS specific C/C++ compiler (see below)
- GNU Make
Then build and run Ollama from the root directory of the repository:
```shell
go run . serve
## Overview
Ollama uses a mix of Go and C/C++ code to interface with GPUs. The C/C++ code is compiled with both CGO and GPU library specific compilers. A set of GNU Makefiles are used to compile the project. GPU Libraries are auto-detected based on the typical environment variables used by the respective libraries, but can be overridden if necessary. The default make target will build the runners and primary Go Ollama application that will run within the repo directory. Throughout the examples below `-j 5` is suggested for 5 parallel jobs to speed up the build. You can adjust the job count based on your CPU Core count to reduce build times. If you want to relocate the built binaries, use the `dist` target and recursively copy the files in `./dist/$OS-$ARCH/` to your desired location. To learn more about the other make targets use `make help`
Once you have built the GPU/CPU runners, you can compile the main application with `go build .`
### MacOS
[Download Go](https://go.dev/dl/)
```bash
make -j 5
```
## macOS (Apple Silicon)
Now you can run `ollama`:
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
## macOS (Intel)
Install prerequisites:
- [CMake](https://cmake.org/download/) or `brew install cmake`
Then, configure and build the project:
```shell
cmake -B build
cmake --build build
```bash
./ollama
```
Lastly, run Ollama:
#### Xcode 15 warnings
```shell
go run . serve
If you are using Xcode newer than version 14, you may see a warning during `go build` about `ld: warning: ignoring duplicate libraries: '-lobjc'` due to Golang issue https://github.com/golang/go/issues/67799 which can be safely ignored. You can suppress the warning with `export CGO_LDFLAGS="-Wl,-no_warn_duplicate_libraries"`
### Linux
#### Linux CUDA (NVIDIA)
_Your operating system distribution may already have packages for NVIDIA CUDA. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_
Install `make`, `gcc` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads)
development and runtime packages.
Typically the makefile will auto-detect CUDA, however, if your Linux distro
or installation approach uses alternative paths, you can specify the location by
overriding `CUDA_PATH` to the location of the CUDA toolkit. You can customize
a set of target CUDA architectures by setting `CUDA_ARCHITECTURES` (e.g. `CUDA_ARCHITECTURES=50;60;70`)
```
make -j 5
```
## Windows
If both v11 and v12 tookkits are detected, runners for both major versions will be built by default. You can build just v12 with `make cuda_v12`
Install prerequisites:
#### Older Linux CUDA (NVIDIA)
- [CMake](https://cmake.org/download/)
- [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) including the Native Desktop Workload
- (Optional) AMD GPU support
- [ROCm](https://rocm.docs.amd.com/en/latest/)
- [Ninja](https://github.com/ninja-build/ninja/releases)
- (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
To support older GPUs with Compute Capability 3.5 or 3.7, you will need to use an older version of the Driver from [Unix Driver Archive](https://www.nvidia.com/en-us/drivers/unix/) (tested with 470) and [CUDA Toolkit Archive](https://developer.nvidia.com/cuda-toolkit-archive) (tested with cuda V11). When you build Ollama, you will need to set two make variable to adjust the minimum compute capability Ollama supports via `make -j 5 CUDA_ARCHITECTURES="35;37;50;52" EXTRA_GOLDFLAGS="\"-X=github.com/ollama/ollama/discover.CudaComputeMajorMin=3\" \"-X=github.com/ollama/ollama/discover.CudaComputeMinorMin=5\""`. To find the Compute Capability of your older GPU, refer to [GPU Compute Capability](https://developer.nvidia.com/cuda-gpus).
Then, configure and build the project:
#### Linux ROCm (AMD)
```shell
cmake -B build
cmake --build build --config Release
_Your operating system distribution may already have packages for AMD ROCm. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_
Install [ROCm](https://rocm.docs.amd.com/en/latest/) development packages first, as well as `make`, `gcc`, and `golang`.
Typically the build scripts will auto-detect ROCm, however, if your Linux distro
or installation approach uses unusual paths, you can specify the location by
specifying an environment variable `HIP_PATH` to the location of the ROCm
install (typically `/opt/rocm`). You can also customize
the AMD GPU targets by setting HIP_ARCHS (e.g. `HIP_ARCHS=gfx1101;gfx1102`)
```
make -j 5
```
> [!IMPORTANT]
> Building for ROCm requires additional flags:
> ```
> cmake -B build -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
> cmake --build build --config Release
> ```
ROCm requires elevated privileges to access the GPU at runtime. On most distros you can add your user account to the `render` group, or run as root.
#### Containerized Linux Build
Lastly, run Ollama:
If you have Docker and buildx available, you can build linux binaries with `./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included. The resulting artifacts are placed in `./dist` and by default the script builds both arm64 and amd64 binaries. If you want to build only amd64, you can build with `PLATFORM=linux/amd64 ./scripts/build_linux.sh`
```shell
go run . serve
### Windows
The following tools are required as a minimal development environment to build CPU inference support.
- Go version 1.22 or higher
- https://go.dev/dl/
- Git
- https://git-scm.com/download/win
- clang with gcc compat and Make. There are multiple options on how to go about installing these tools on Windows. We have verified the following, but others may work as well:
- [MSYS2](https://www.msys2.org/)
- After installing, from an MSYS2 terminal, run `pacman -S mingw-w64-clang-x86_64-gcc-compat mingw-w64-clang-x86_64-clang make` to install the required tools
- Assuming you used the default install prefix for msys2 above, add `C:\msys64\clang64\bin` and `c:\msys64\usr\bin` to your environment variable `PATH` where you will perform the build steps below (e.g. system-wide, account-level, powershell, cmd, etc.)
> [!NOTE]
> Due to bugs in the GCC C++ library for unicode support, Ollama should be built with clang on windows.
```
make -j 5
```
## Windows (ARM)
#### GPU Support
Windows ARM does not support additional acceleration libraries at this time. Do not use cmake, simply `go run` or `go build`.
The GPU tools require the Microsoft native build tools. To build either CUDA or ROCm, you must first install MSVC via Visual Studio:
## Linux
- Make sure to select `Desktop development with C++` as a Workload during the Visual Studio install
- You must complete the Visual Studio install and run it once **BEFORE** installing CUDA or ROCm for the tools to properly register
- Add the location of the **64 bit (x64)** compiler (`cl.exe`) to your `PATH`
- Note: the default Developer Shell may configure the 32 bit (x86) compiler which will lead to build failures. Ollama requires a 64 bit toolchain.
Install prerequisites:
#### Windows CUDA (NVIDIA)
- [CMake](https://cmake.org/download/) or `sudo apt install cmake` or `sudo dnf install cmake`
- (Optional) AMD GPU support
- [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html)
- (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads)
In addition to the common Windows development tools and MSVC described above:
> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.
- [NVIDIA CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html)
#### Windows ROCm (AMD Radeon)
Then, configure and build the project:
In addition to the common Windows development tools and MSVC described above:
```shell
cmake -B build
cmake --build build
- [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
#### Windows arm64
The default `Developer PowerShell for VS 2022` may default to x86 which is not what you want. To ensure you get an arm64 development environment, start a plain PowerShell terminal and run:
```powershell
import-module 'C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -Arch arm64 -vsinstallpath 'C:\\Program Files\\Microsoft Visual Studio\\2022\\Community' -skipautomaticlocation
```
Lastly, run Ollama:
You can confirm with `write-host $env:VSCMD_ARG_TGT_ARCH`
```shell
go run . serve
Follow the instructions at https://www.msys2.org/wiki/arm64/ to set up an arm64 msys2 environment. Ollama requires gcc and mingw32-make to compile, which is not currently available on Windows arm64, but a gcc compatibility adapter is available via `mingw-w64-clang-aarch64-gcc-compat`. At a minimum you will need to install the following:
```
pacman -S mingw-w64-clang-aarch64-clang mingw-w64-clang-aarch64-gcc-compat mingw-w64-clang-aarch64-make make
```
## Docker
You will need to ensure your PATH includes go, cmake, gcc and clang mingw32-make to build ollama from source. (typically `C:\msys64\clangarm64\bin\`)
```shell
docker build .
## Advanced CPU Vector Settings
On x86, running `make` will compile several CPU runners which can run on different CPU families. At runtime, Ollama will auto-detect the best variation to load. If GPU libraries are present at build time, Ollama also compiles GPU runners with the `AVX` CPU vector feature enabled. This provides a good performance balance when loading large models that split across GPU and CPU with broad compatibility. Some users may prefer no vector extensions (e.g. older Xeon/Celeron processors, or hypervisors that mask the vector features) while other users may prefer turning on many more vector extensions to further improve performance for split model loads.
To customize the set of CPU vector features enabled for a CPU runner and all GPU runners, use CUSTOM_CPU_FLAGS during the build.
To build without any vector flags:
```
make CUSTOM_CPU_FLAGS=""
```
### ROCm
```shell
docker build --build-arg FLAVOR=rocm .
To build with both AVX and AVX2:
```
make CUSTOM_CPU_FLAGS=avx,avx2
```
## Running tests
To build with AVX512 features turned on:
To run tests, use `go test`:
```shell
go test ./...
```
make CUSTOM_CPU_FLAGS=avx,avx2,avx512,avx512vbmi,avx512vnni,avx512bf16
```
> NOTE: In rare cirumstances, you may nedd to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see build or
> test failures resulting from your change(s), if any, locally, but CI will
> break.
>
> If you see failures in CI, you can either keep pushing changes to see if the
> CI build passes, or you can enable the "synctest" package locally to see the
> failures before pushing.
>
> To enable the "synctest" package for testing, run the following command:
>
> ```shell
> GOEXPERIMENT=synctest go test ./...
> ```
>
> If you wish to enable synctest for all go commands, you can set the
> `GOEXPERIMENT` environment variable in your shell profile or by using:
>
> ```shell
> go env -w GOEXPERIMENT=synctest
> ```
>
> Which will enable the "synctest" package for all go commands without needing
> to set it for all shell sessions.
>
> The synctest package is not required for production builds.
## Library detection
Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
* `./lib/ollama` (Windows)
* `../lib/ollama` (Linux)
* `.` (macOS)
* `build/lib/ollama` (for development)
If the libraries are not found, Ollama will not run with any acceleration libraries.
> [!NOTE]
> If you are experimenting with different flags, make sure to do a `make clean` between each change to ensure everything is rebuilt with the new compiler flags

View File

@@ -2,7 +2,7 @@
### CPU only
```shell
```bash
docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
```
@@ -11,46 +11,42 @@ Install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-
#### Install with Apt
1. Configure the repository
```shell
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
| sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
| sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
| sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
```
```bash
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
| sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
| sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
| sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
```
2. Install the NVIDIA Container Toolkit packages
```shell
sudo apt-get install -y nvidia-container-toolkit
```
```bash
sudo apt-get install -y nvidia-container-toolkit
```
#### Install with Yum or Dnf
1. Configure the repository
```shell
curl -s -L https://nvidia.github.io/libnvidia-container/stable/rpm/nvidia-container-toolkit.repo \
| sudo tee /etc/yum.repos.d/nvidia-container-toolkit.repo
```
```bash
curl -s -L https://nvidia.github.io/libnvidia-container/stable/rpm/nvidia-container-toolkit.repo \
| sudo tee /etc/yum.repos.d/nvidia-container-toolkit.repo
```
2. Install the NVIDIA Container Toolkit packages
```shell
sudo yum install -y nvidia-container-toolkit
```
```bash
sudo yum install -y nvidia-container-toolkit
```
#### Configure Docker to use Nvidia driver
```shell
```
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker
```
#### Start the container
```shell
```bash
docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
```
@@ -61,7 +57,7 @@ docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ol
To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following command:
```shell
```
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
```
@@ -69,7 +65,7 @@ docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 114
Now you can run a model:
```shell
```
docker exec -it ollama ollama run llama3.2
```

View File

@@ -20,18 +20,12 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?
By default, Ollama uses a context window size of 4096 tokens, unless you have a single GPU with <= 4 GB of VRAM, in which case it will default to 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
```shell
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
```
By default, Ollama uses a context window size of 2048 tokens.
To change this when using `ollama run`, use `/set parameter`:
```shell
/set parameter num_ctx 8192
```
/set parameter num_ctx 4096
```
When using the API, specify the `num_ctx` parameter:
@@ -41,7 +35,7 @@ curl http://localhost:11434/api/generate -d '{
"model": "llama3.2",
"prompt": "Why is the sky blue?",
"options": {
"num_ctx": 8192
"num_ctx": 4096
}
}'
```
@@ -52,15 +46,10 @@ Use the `ollama ps` command to see what models are currently loaded into memory.
```shell
ollama ps
NAME ID SIZE PROCESSOR UNTIL
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
> **Output**:
>
> ```
> NAME ID SIZE PROCESSOR UNTIL
> llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
> ```
The `Processor` column will show which memory the model was loaded in to:
* `100% GPU` means the model was loaded entirely into the GPU
* `100% CPU` means the model was loaded entirely in system memory
@@ -77,7 +66,7 @@ If Ollama is run as a macOS application, environment variables should be set usi
1. For each environment variable, call `launchctl setenv`.
```bash
launchctl setenv OLLAMA_HOST "0.0.0.0:11434"
launchctl setenv OLLAMA_HOST "0.0.0.0"
```
2. Restart Ollama application.
@@ -92,14 +81,14 @@ If Ollama is run as a systemd service, environment variables should be set using
```ini
[Service]
Environment="OLLAMA_HOST=0.0.0.0:11434"
Environment="OLLAMA_HOST=0.0.0.0"
```
3. Save and exit.
4. Reload `systemd` and restart Ollama:
```shell
```bash
systemctl daemon-reload
systemctl restart ollama
```
@@ -193,13 +182,6 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
```
# Allow all Chrome, Firefox, and Safari extensions
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
```
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Where are models stored?
@@ -239,19 +221,16 @@ properties.
If you are using the API you can preload a model by sending the Ollama server an empty request. This works with both the `/api/generate` and `/api/chat` API endpoints.
To preload the mistral model using the generate endpoint, use:
```shell
curl http://localhost:11434/api/generate -d '{"model": "mistral"}'
```
To use the chat completions endpoint, use:
```shell
curl http://localhost:11434/api/chat -d '{"model": "mistral"}'
```
To preload a model using the CLI, use the command:
```shell
ollama run llama3.2 ""
```
@@ -271,13 +250,11 @@ If you're using the API, use the `keep_alive` parameter with the `/api/generate`
* '0' which will unload the model immediately after generating a response
For example, to preload a model and leave it in memory use:
```shell
curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": -1}'
```
To unload the model and free up memory use:
```shell
curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": 0}'
```

View File

@@ -7,7 +7,7 @@ Check your compute compatibility to see if your card is supported:
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
| 9.0 | NVIDIA | `H200` `H100` |
| 9.0 | NVIDIA | `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |

View File

@@ -20,13 +20,13 @@ Make sure that you use the same base model in the `FROM` command as you used to
Now run `ollama create` from the directory where the `Modelfile` was created:
```shell
```bash
ollama create my-model
```
Lastly, test the model:
```shell
```bash
ollama run my-model
```

View File

@@ -75,7 +75,7 @@ RestartSec=3
Environment="PATH=$PATH"
[Install]
WantedBy=multi-user.target
WantedBy=default.target
```
Then start the service:
@@ -119,7 +119,7 @@ sudo systemctl status ollama
To customize the installation of Ollama, you can edit the systemd service file or the environment variables by running:
```shell
```
sudo systemctl edit ollama
```
@@ -152,7 +152,7 @@ Use `OLLAMA_VERSION` environment variable with the install script to install a s
For example:
```shell
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.3.9 sh
```
## Viewing logs
@@ -186,9 +186,3 @@ sudo rm -r /usr/share/ollama
sudo userdel ollama
sudo groupdel ollama
```
Remove installed libraries:
```shell
sudo rm -rf /usr/local/lib/ollama
```

View File

@@ -28,7 +28,7 @@ A model file is the blueprint to create and share models with Ollama.
The format of the `Modelfile`:
```
```modelfile
# comment
INSTRUCTION arguments
```
@@ -49,7 +49,7 @@ INSTRUCTION arguments
An example of a `Modelfile` creating a mario blueprint:
```
```modelfile
FROM llama3.2
# sets the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1
@@ -67,32 +67,28 @@ To use this:
3. `ollama run choose-a-model-name`
4. Start using the model!
More examples are available in the [examples directory](../examples).
To view the Modelfile of a given model, use the `ollama show --modelfile` command.
```shell
ollama show --modelfile llama3.2
```
```bash
> ollama show --modelfile llama3.2
# Modelfile generated by "ollama show"
# To build a new Modelfile based on this one, replace the FROM line with:
# FROM llama3.2:latest
FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
> **Output**:
>
> ```
> # Modelfile generated by "ollama show"
> # To build a new Modelfile based on this one, replace the FROM line with:
> # FROM llama3.2:latest
> FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
> TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
>
> {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
>
> {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
>
> {{ .Response }}<|eot_id|>"""
> PARAMETER stop "<|start_header_id|>"
> PARAMETER stop "<|end_header_id|>"
> PARAMETER stop "<|eot_id|>"
> PARAMETER stop "<|reserved_special_token"
> ```
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>"""
PARAMETER stop "<|start_header_id|>"
PARAMETER stop "<|end_header_id|>"
PARAMETER stop "<|eot_id|>"
PARAMETER stop "<|reserved_special_token"
```
## Instructions
@@ -100,13 +96,13 @@ ollama show --modelfile llama3.2
The `FROM` instruction defines the base model to use when creating a model.
```
```modelfile
FROM <model name>:<tag>
```
#### Build from existing model
```
```modelfile
FROM llama3.2
```
@@ -117,7 +113,7 @@ Additional models can be found at:
#### Build from a Safetensors model
```
```modelfile
FROM <model directory>
```
@@ -131,7 +127,7 @@ Currently supported model architectures:
#### Build from a GGUF file
```
```modelfile
FROM ./ollama-model.gguf
```
@@ -142,7 +138,7 @@ The GGUF file location should be specified as an absolute path or relative to th
The `PARAMETER` instruction defines a parameter that can be set when the model is run.
```
```modelfile
PARAMETER <parameter> <parametervalue>
```
@@ -159,6 +155,7 @@ PARAMETER <parameter> <parametervalue>
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
| num_predict | Maximum number of tokens to predict when generating text. (Default: -1, infinite generation) | int | num_predict 42 |
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
@@ -189,7 +186,7 @@ TEMPLATE """{{ if .System }}<|im_start|>system
The `SYSTEM` instruction specifies the system message to be used in the template, if applicable.
```
```modelfile
SYSTEM """<system message>"""
```
@@ -199,7 +196,7 @@ The `ADAPTER` instruction specifies a fine tuned LoRA adapter that should apply
#### Safetensor adapter
```
```modelfile
ADAPTER <path to safetensor adapter>
```
@@ -210,7 +207,7 @@ Currently supported Safetensor adapters:
#### GGUF adapter
```
```modelfile
ADAPTER ./ollama-lora.gguf
```
@@ -218,7 +215,7 @@ ADAPTER ./ollama-lora.gguf
The `LICENSE` instruction allows you to specify the legal license under which the model used with this Modelfile is shared or distributed.
```
```modelfile
LICENSE """
<license text>
"""
@@ -228,7 +225,7 @@ LICENSE """
The `MESSAGE` instruction allows you to specify a message history for the model to use when responding. Use multiple iterations of the MESSAGE command to build up a conversation which will guide the model to answer in a similar way.
```
```modelfile
MESSAGE <role> <message>
```
@@ -243,7 +240,7 @@ MESSAGE <role> <message>
#### Example conversation
```
```modelfile
MESSAGE user Is Toronto in Canada?
MESSAGE assistant yes
MESSAGE user Is Sacramento in Canada?

View File

@@ -1,7 +1,6 @@
# OpenAI compatibility
> [!NOTE]
> OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/ollama/ollama/blob/main/docs/api.md).
> **Note:** OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/ollama/ollama/blob/main/docs/api.md).
Ollama provides experimental compatibility with parts of the [OpenAI API](https://platform.openai.com/docs/api-reference) to help connect existing applications to Ollama.
@@ -60,10 +59,8 @@ embeddings = client.embeddings.create(
input=["why is the sky blue?", "why is the grass green?"],
)
```
#### Structured outputs
```python
```py
from pydantic import BaseModel
from openai import OpenAI
@@ -147,7 +144,7 @@ const embedding = await openai.embeddings.create({
### `curl`
```shell
``` shell
curl http://localhost:11434/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
@@ -322,7 +319,7 @@ ollama pull llama3.2
For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
```shell
```
ollama cp llama3.2 gpt-3.5-turbo
```
@@ -346,7 +343,7 @@ curl http://localhost:11434/v1/chat/completions \
The OpenAI API does not have a way of setting the context size for a model. If you need to change the context size, create a `Modelfile` which looks like:
```
```modelfile
FROM <some model>
PARAMETER num_ctx <context size>
```

View File

@@ -12,7 +12,7 @@ A basic Go template consists of three main parts:
Here's an example of a simple chat template:
```go
```gotmpl
{{- range .Messages }}
{{ .Role }}: {{ .Content }}
{{- end }}
@@ -162,6 +162,6 @@ CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://o
Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.
```go
```gotmpl
[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
```

View File

@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
On **Linux** systems with systemd, the logs can be found with this command:
```shell
journalctl -u ollama --no-pager --follow --pager-end
journalctl -u ollama --no-pager
```
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
@@ -17,7 +17,6 @@ When you run Ollama in a **container**, the logs go to stdout/stderr in the cont
```shell
docker logs <container-name>
```
(Use `docker ps` to find the container name)
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
@@ -26,9 +25,9 @@ When you run Ollama on **Windows**, there are a few different locations. You can
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
```powershell
$env:OLLAMA_DEBUG="1"
& "ollama app.exe"
@@ -50,13 +49,12 @@ Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:
```shell
```
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
```
You can see what features your CPU has with the following.
```shell
```
cat /proc/cpuinfo| grep flags | head -1
```
@@ -64,13 +62,13 @@ cat /proc/cpuinfo| grep flags | head -1
If you run into problems on Linux and want to install an older version, or you'd like to try out a pre-release before it's officially released, you can tell the install script which version to install.
```shell
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
```sh
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
```
## Linux docker
## Linux tmp noexec
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
## NVIDIA GPU Discovery
@@ -99,6 +97,8 @@ On linux, AMD GPU access typically requires `video` and/or `render` group member
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems
- `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported

View File

@@ -47,7 +47,6 @@ If Ollama is already running, Quit the tray application and relaunch it from the
## API Access
Here's a quick example showing API access from `powershell`
```powershell
(Invoke-WebRequest -method POST -Body '{"model":"llama3.2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
```
@@ -55,13 +54,14 @@ Here's a quick example showing API access from `powershell`
## Troubleshooting
Ollama on Windows stores files in a few different locations. You can view them in
the explorer window by hitting `<Ctrl>+R` and type in:
the explorer window by hitting `<cmd>+R` and type in:
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
- *app.log* contains most resent logs from the GUI application
- *server.log* contains the most recent server logs
- *upgrade.log* contains log output for upgrades
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
## Uninstall
@@ -80,11 +80,9 @@ help you keep up to date.
If you'd like to install or integrate Ollama as a service, a standalone
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
same directory. This allows for embedding Ollama in existing applications, or
running it as a system service via `ollama serve` with tools such as
[NSSM](https://nssm.cc/).
and GPU library dependencies for Nvidia and AMD. This allows for embedding
Ollama in existing applications, or running it as a system service via `ollama
serve` with tools such as [NSSM](https://nssm.cc/).
> [!NOTE]
> If you are upgrading from a prior version, you should remove the old directories first.

View File

@@ -53,8 +53,8 @@ func Host() *url.URL {
}
}
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
func AllowedOrigins() (origins []string) {
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" {
origins = strings.Split(s, ",")
}
@@ -73,7 +73,6 @@ func AllowedOrigins() (origins []string) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
)
return origins
@@ -167,9 +166,7 @@ var (
// MultiUserCache optimizes prompt caching for multi-user scenarios
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
// Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Int64("OLLAMA_CONTEXT_LENGTH", -1)
NewRunners = Bool("OLLAMA_NEW_RUNNERS")
)
func String(s string) func() string {
@@ -227,20 +224,6 @@ func Uint64(key string, defaultValue uint64) func() uint64 {
}
}
func Int64(key string, defaultValue int64) func() int64 {
return func() int64 {
if s := Var(key); s != "" {
if n, err := strconv.ParseInt(s, 10, 64); err != nil {
slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
} else {
return n
}
}
return defaultValue
}
}
// Set aside VRAM per GPU
var GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0)
@@ -266,11 +249,10 @@ func AsMap() map[string]EnvVar {
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default 4096 or 2048 with low VRAM)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_NEW_RUNNERS": {"OLLAMA_NEW_RUNNERS", NewRunners(), "Enable the new Ollama engine"},
// Informational
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
@@ -309,3 +291,12 @@ func Values() map[string]string {
func Var(key string) string {
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
}
// On windows, we keep the binary at the top directory, but
// other platforms use a "bin" directory, so this returns ".."
func LibRelativeToExe() string {
if runtime.GOOS == "windows" {
return "."
}
return ".."
}

View File

@@ -69,7 +69,6 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
{"http://10.0.0.1", []string{
"http://10.0.0.1",
@@ -89,7 +88,6 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
{"http://172.16.0.1,https://192.168.0.1", []string{
"http://172.16.0.1",
@@ -110,7 +108,6 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
{"http://totally.safe,http://definitely.legit", []string{
"http://totally.safe",
@@ -131,14 +128,13 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
}
for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) {
t.Setenv("OLLAMA_ORIGINS", tt.value)
if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" {
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
}
})
@@ -276,19 +272,3 @@ func TestVar(t *testing.T) {
})
}
}
func TestContextLength(t *testing.T) {
cases := map[string]int64{
"": -1,
"4096": 4096,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", k)
if i := ContextLength(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
}
})
}
}

View File

@@ -40,6 +40,8 @@ func HumanBytes(b int64) string {
}
switch {
case value >= 100:
return fmt.Sprintf("%d %s", int(value), unit)
case value >= 10:
return fmt.Sprintf("%d %s", int(value), unit)
case value != math.Trunc(value):

View File

@@ -1,91 +0,0 @@
package format
import (
"testing"
)
func TestHumanBytes(t *testing.T) {
type testCase struct {
input int64
expected string
}
tests := []testCase{
// Test bytes (B)
{0, "0 B"},
{1, "1 B"},
{999, "999 B"},
// Test kilobytes (KB)
{1000, "1 KB"},
{1500, "1.5 KB"},
{999999, "999 KB"},
// Test megabytes (MB)
{1000000, "1 MB"},
{1500000, "1.5 MB"},
{999999999, "999 MB"},
// Test gigabytes (GB)
{1000000000, "1 GB"},
{1500000000, "1.5 GB"},
{999999999999, "999 GB"},
// Test terabytes (TB)
{1000000000000, "1 TB"},
{1500000000000, "1.5 TB"},
{1999999999999, "2.0 TB"},
// Test fractional values
{1234, "1.2 KB"},
{1234567, "1.2 MB"},
{1234567890, "1.2 GB"},
}
for _, tc := range tests {
t.Run(tc.expected, func(t *testing.T) {
result := HumanBytes(tc.input)
if result != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, result)
}
})
}
}
func TestHumanBytes2(t *testing.T) {
type testCase struct {
input uint64
expected string
}
tests := []testCase{
// Test bytes (B)
{0, "0 B"},
{1, "1 B"},
{1023, "1023 B"},
// Test kibibytes (KiB)
{1024, "1.0 KiB"},
{1536, "1.5 KiB"},
{1048575, "1024.0 KiB"},
// Test mebibytes (MiB)
{1048576, "1.0 MiB"},
{1572864, "1.5 MiB"},
{1073741823, "1024.0 MiB"},
// Test gibibytes (GiB)
{1073741824, "1.0 GiB"},
{1610612736, "1.5 GiB"},
{2147483648, "2.0 GiB"},
}
for _, tc := range tests {
t.Run(tc.expected, func(t *testing.T) {
result := HumanBytes2(tc.input)
if result != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, result)
}
})
}
}

View File

@@ -12,9 +12,6 @@ func TestHumanNumber(t *testing.T) {
testCases := []testCase{
{0, "0"},
{999, "999"},
{1000, "1K"},
{1001, "1K"},
{1000000, "1M"},
{125000000, "125M"},
{500500000, "500.50M"},

View File

@@ -5,7 +5,7 @@ import (
"time"
)
func assertEqual(t *testing.T, a any, b any) {
func assertEqual(t *testing.T, a interface{}, b interface{}) {
if a != b {
t.Errorf("Assert failed, expected %v, got %v", b, a)
}

View File

@@ -1,13 +0,0 @@
package fs
type Config interface {
Architecture() string
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
}

View File

@@ -33,7 +33,7 @@ func (kv KV) Kind() string {
}
func (kv KV) ParameterCount() uint64 {
return keyValue(kv, "general.parameter_count", uint64(0))
return keyValue[uint64](kv, "general.parameter_count")
}
func (kv KV) FileType() fileType {
@@ -100,47 +100,27 @@ func (kv KV) Float(key string, defaultValue ...float32) float32 {
return keyValue(kv, key, append(defaultValue, 0)...)
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
return keyValue(kv, key, append(defaultValue, false)...)
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
}
r := keyValue(kv, key, &array{})
s := make([]string, r.size)
for i := range r.size {
s[i] = r.values[i].(string)
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
return s
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
r := keyValue(kv, key, &array{})
s := make([]uint32, r.size)
for i := range r.size {
s[i] = uint32(r.values[i].(int32))
}
return s
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"gemma3",
"mistral3",
"llama4",
}, kv.Architecture())
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
*array[string] | *array[float32] | *array[float64] | *array[bool]
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
@@ -173,17 +153,19 @@ func (s Tensors) Items(prefix ...string) []*Tensor {
return items
}
func (ts Tensors) GroupLayers() map[string]Layer {
func (ts Tensors) Layers() map[string]Layer {
layers := make(map[string]Layer)
for _, t := range ts.items {
parts := strings.Split(t.Name, ".")
if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
if len(parts) > index+2 {
// blk and mm should have a number after them, join it
parts = append(
[]string{strings.Join(parts[:index+2], ".")},
parts[index+2:]...)
}
if i := slices.Index(parts, "blk"); i > 0 {
parts = append([]string{
strings.Join(parts[:i], "."),
strings.Join(parts[i:i+2], "."),
}, parts[i+2:]...)
} else if i == 0 {
parts = append([]string{
strings.Join(parts[i:i+2], "."),
}, parts[i+2:]...)
}
if _, ok := layers[parts[0]]; !ok {
@@ -227,26 +209,11 @@ func (t Tensor) block() (n int) {
func (t Tensor) blockSize() uint64 {
switch t.Kind {
case
0, // F32
1, // F16
24, // I8
25, // I16
26, // I32
27, // I64
28, // F64
30: // BF16
case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
return 1
case
2, // Q4_0
3, // Q4_1
6, // Q5_0
7, // Q5_1
8, // Q8_0
9, // Q8_1
20: // IQ4_NL
case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
return 32
default:
default: // All others
return 256
}
}
@@ -270,7 +237,7 @@ func (t Tensor) typeSize() uint64 {
case 8: // Q8_0
return 2 + blockSize
case 9: // Q8_1
return 2 + 2 + blockSize
return 4 + 4 + blockSize
case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
@@ -282,7 +249,7 @@ func (t Tensor) typeSize() uint64 {
case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2
case 15: // Q8_K
return 4 + blockSize + 2*blockSize/16
return 2 + blockSize + 2*blockSize/16
case 16: // IQ2_XXS
return 2 + 2*blockSize/8
case 17: // IQ2_XS
@@ -311,8 +278,6 @@ func (t Tensor) typeSize() uint64 {
return 8
case 29: // IQ1_M
return blockSize/8 + blockSize/16 + blockSize/32
case 30: // BF16
return 2
default:
return 0
}
@@ -330,10 +295,6 @@ func (t Tensor) Size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize()
}
func (t Tensor) Type() string {
return fileType(t.Kind).String()
}
type container interface {
Name() string
Decode(io.ReadSeeker) (model, error)
@@ -375,8 +336,13 @@ func DetectContentType(b []byte) string {
// Decode decodes a GGML model from the given reader.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If the maxArraySize is negative, all arrays are collected.
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
if maxArraySize == 0 {
maxArraySize = 1024
}
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32
@@ -411,26 +377,23 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
}, offset, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
embedding := llm.KV().EmbeddingLength()
heads := llm.KV().HeadCount()
headsKV := llm.KV().HeadCountKV()
vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
embeddingHeads := f.KV().EmbeddingHeadCount()
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
embeddingHeads := llm.KV().EmbeddingHeadCount()
embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
embeddingHeadsV := llm.KV().EmbeddingHeadCountV()
layers := f.Tensors().GroupLayers()
layers := llm.Tensors().Layers()
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
kv = make([]uint64, f.KV().BlockCount())
for i := range kv {
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
}
kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
switch f.KV().Architecture() {
case "llama", "llama4":
switch llm.KV().Architecture() {
case "llama":
fullOffload = max(
4*batch*(1+4*embedding+context*(1+heads)),
4*batch*(embedding+vocab),
@@ -444,7 +407,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
// mixtral 8x22b
ff := uint64(f.KV().Uint("feed_forward_length"))
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
partialOffload = max(
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@@ -461,14 +424,16 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
case "mllama":
var visionTokens, tiles uint64 = 1601, 4
crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers")
for i := range kv {
if slices.Contains(crossAttentionLayers, int32(i)) {
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
4 * // sizeof(float32)
visionTokens *
tiles
}
if crossAttentionLayers, ok := llm.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
kv = headsKV *
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
(2* // sizeof(float16)
(llm.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
context +
4* // sizeof(float32)
uint64(crossAttentionLayers.size)* // num cross attention layers
visionTokens*
tiles)
}
fullOffload = max(
@@ -478,7 +443,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
)
var ropeFreqsCount uint64
if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok {
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
ropeFreqsCount = ropeFreqsWeights.parameters()
}
@@ -492,7 +457,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
// vocab graph
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
case "gemma", "gemma2", "gemma3":
case "gemma", "gemma2":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -504,20 +469,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
4*embeddingHeadsK*context*8+
embedding*embeddingHeadsK*heads*9/16,
)
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
// engine. Gemma3 always uses the Ollama engine.
if f.KV().Architecture() == "gemma3" {
const gemma3GlobalCacheCount = 6
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
for i := range kv {
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
// layers are the smaller local (sliding) layers.
if (i+1)%gemma3GlobalCacheCount != 0 {
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
}
}
}
case "command-r":
fullOffload = max(
4*batch*(embedding+vocab),
@@ -595,74 +546,21 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
return
}
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() {
case "mllama":
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3", "mistral3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph
}
return weights, graphSize
}
// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
func (llm GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
}
// SupportsFlashAttention checks if the model supports flash attention
func (f GGML) SupportsFlashAttention() bool {
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
func (llm GGML) SupportsFlashAttention() bool {
_, isEmbedding := llm.KV()[fmt.Sprintf("%s.pooling_type", llm.KV().Architecture())]
if isEmbedding {
return false
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()
headCountK := llm.KV().EmbeddingHeadCountK()
headCountV := llm.KV().EmbeddingHeadCountV()
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}

View File

@@ -1,271 +0,0 @@
package ggml
import (
"maps"
"math"
"slices"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestTensorLayers(t *testing.T) {
tensors := make(map[string]*Tensor)
for _, name := range []string{
"token_embd.weight",
"blk.0.attn_k.weight",
"blk.0.attn_output.weight",
"blk.0.attn_q.weight",
"blk.0.attn_v.weight",
"blk.0.attn_norm.weight",
"blk.0.ffn_down.weight",
"blk.0.ffn_gate.weight",
"blk.0.ffn_up.weight",
"blk.0.ffn_norm.weight",
"output_norm.weight",
"mm.0.bias",
"mm.0.weight",
"v.blk.0.attn_k.weight",
"v.blk.0.attn_output.weight",
"v.blk.0.attn_q.weight",
"v.blk.0.attn_v.weight",
"v.blk.0.attn_norm.weight",
"v.blk.0.ffn_down.weight",
"v.blk.0.ffn_gate.weight",
"v.blk.0.ffn_up.weight",
"v.blk.0.ffn_norm.weight",
"v.patch_embd.weight",
"v.position_embd.gate",
"v.position_embd.weight",
} {
tensors[name] = &Tensor{Name: name}
}
cases := []struct {
name string
items []*Tensor
want map[string]Layer
}{
{
name: "text",
items: slices.Collect(func(yield func(*Tensor) bool) {
for k, v := range tensors {
if !strings.HasPrefix(k, "mm.") && !strings.HasPrefix(k, "v.") {
if !yield(v) {
return
}
}
}
}),
want: map[string]Layer{
"blk.0": {
"attn_k.weight": tensors["blk.0.attn_k.weight"],
"attn_q.weight": tensors["blk.0.attn_q.weight"],
"attn_v.weight": tensors["blk.0.attn_v.weight"],
"attn_output.weight": tensors["blk.0.attn_output.weight"],
"attn_norm.weight": tensors["blk.0.attn_norm.weight"],
"ffn_down.weight": tensors["blk.0.ffn_down.weight"],
"ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
"ffn_up.weight": tensors["blk.0.ffn_up.weight"],
"ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
},
"token_embd": {"weight": tensors["token_embd.weight"]},
"output_norm": {"weight": tensors["output_norm.weight"]},
},
},
{
name: "vision",
items: slices.Collect(func(yield func(*Tensor) bool) {
for k, v := range tensors {
if strings.HasPrefix(k, "mm.") || strings.HasPrefix(k, "v.") {
if !yield(v) {
return
}
}
}
}),
want: map[string]Layer{
"mm.0": {
"bias": tensors["mm.0.bias"],
"weight": tensors["mm.0.weight"],
},
"v.blk.0": {
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
},
"v": {
"patch_embd.weight": tensors["v.patch_embd.weight"],
"position_embd.gate": tensors["v.position_embd.gate"],
"position_embd.weight": tensors["v.position_embd.weight"],
},
},
},
{
name: "vision and text",
items: slices.Collect(maps.Values(tensors)),
want: map[string]Layer{
"blk.0": {
"attn_k.weight": tensors["blk.0.attn_k.weight"],
"attn_q.weight": tensors["blk.0.attn_q.weight"],
"attn_v.weight": tensors["blk.0.attn_v.weight"],
"attn_output.weight": tensors["blk.0.attn_output.weight"],
"attn_norm.weight": tensors["blk.0.attn_norm.weight"],
"ffn_down.weight": tensors["blk.0.ffn_down.weight"],
"ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
"ffn_up.weight": tensors["blk.0.ffn_up.weight"],
"ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
},
"token_embd": {"weight": tensors["token_embd.weight"]},
"output_norm": {"weight": tensors["output_norm.weight"]},
"mm.0": {
"bias": tensors["mm.0.bias"],
"weight": tensors["mm.0.weight"],
},
"v.blk.0": {
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
},
"v": {
"patch_embd.weight": tensors["v.patch_embd.weight"],
"position_embd.gate": tensors["v.position_embd.gate"],
"position_embd.weight": tensors["v.position_embd.weight"],
},
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := Tensors{items: tt.items}.GroupLayers()
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("unexpected layers (-got +want):\n%s", diff)
}
})
}
}
// ref: https://github.com/ggml-org/llama.cpp/blob/a82c9e7c23ef6db48cebfa194dc9cebbc4ac3552/ggml/src/ggml.c#L572
func TestTensorTypes(t *testing.T) {
cases := []struct {
kind uint32
blockSize uint64
typeSize uint64
}{
{0, 1, 4},
{1, 1, 2},
{2, 32, 18},
{3, 32, 20},
{6, 32, 22},
{7, 32, 24},
{8, 32, 34},
{9, 32, 36},
{10, 256, 84},
{11, 256, 110},
{12, 256, 144},
{13, 256, 176},
{14, 256, 210},
{15, 256, 292},
{16, 256, 66},
{17, 256, 74},
{18, 256, 98},
{19, 256, 50},
{20, 32, 18},
{21, 256, 110},
{22, 256, 82},
{23, 256, 136},
{24, 1, 1},
{25, 1, 2},
{26, 1, 4},
{27, 1, 8},
{28, 1, 8},
{29, 256, 56},
{30, 1, 2},
}
for _, tt := range cases {
t.Run(strconv.Itoa(int(tt.kind)), func(t *testing.T) {
tensor := Tensor{Kind: tt.kind}
if tensor.blockSize() != tt.blockSize {
t.Errorf("unexpected block size: got=%d want=%d", tensor.blockSize(), tt.blockSize)
}
if tensor.typeSize() != tt.typeSize {
t.Errorf("unexpected type size: got=%d want=%d", tensor.typeSize(), tt.typeSize)
}
})
}
}
func TestKeyValue(t *testing.T) {
kv := KV{
"general.architecture": "test",
"test.strings": &array[string]{size: 3, values: []string{"a", "b", "c"}},
"test.float32s": &array[float32]{size: 3, values: []float32{1.0, 2.0, 3.0}},
"test.int32s": &array[int32]{size: 3, values: []int32{1, 2, 3}},
"test.uint32s": &array[uint32]{size: 3, values: []uint32{1, 2, 3}},
}
if diff := cmp.Diff(kv.Strings("strings"), []string{"a", "b", "c"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("nonexistent.strings"), []string(nil)); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("default.strings", []string{"ollama"}), []string{"ollama"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("float32s"), []float32{1.0, 2.0, 3.0}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("nonexistent.float32s"), []float32(nil)); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("default.float32s", []float32{math.MaxFloat32}), []float32{math.MaxFloat32}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("int32s"), []int32{1, 2, 3}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("nonexistent.int32s"), []int32(nil)); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("default.int32s", []int32{math.MaxInt32}), []int32{math.MaxInt32}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("uint32s"), []uint32{1, 2, 3}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("nonexistent.uint32s"), []uint32(nil)); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("default.uint32s", []uint32{math.MaxUint32}), []uint32{math.MaxUint32}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
}

View File

@@ -36,6 +36,10 @@ type containerGGUF struct {
maxArraySize int
}
func (c *containerGGUF) canCollectArray(size int) bool {
return c.maxArraySize < 0 || size <= c.maxArraySize
}
func (c *containerGGUF) Name() string {
return "gguf"
}
@@ -231,7 +235,10 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
// patch KV with parameter count
llm.kv["general.parameter_count"] = llm.parameters
alignment := llm.kv.Uint("general.alignment", 32)
alignment, ok := llm.kv["general.alignment"].(uint32)
if !ok {
alignment = 32
}
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
@@ -291,23 +298,6 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
return b.String(), nil
}
func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
for i := range a.size {
if a.values != nil {
e, err := readGGUFV1String(llm, r)
if err != nil {
return nil, err
}
a.values[i] = e
} else {
discardGGUFString(llm, r)
}
}
return a, nil
}
func discardGGUFString(llm *gguf, r io.Reader) error {
buf := llm.scratch[:8]
_, err := io.ReadFull(r, buf)
@@ -365,44 +355,78 @@ func writeGGUFString(w io.Writer, s string) error {
return err
}
func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
for i := range a.size {
if a.values != nil {
e, err := readGGUFString(llm, r)
if err != nil {
return nil, err
}
type array struct {
size int
values []any
}
func (a *array) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
n, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, 0, int(n))
}
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
e, err = readGGUFV1String(llm, r)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
if err != nil {
return nil, err
}
if a.values != nil {
a.values[i] = e
} else {
discardGGUFString(llm, r)
}
}
return a, nil
}
type array[T any] struct {
// size is the actual size of the array
size int
// values is the array of values. this is nil if the array is larger than configured maxSize
values []T
}
func (a *array[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func newArray[T any](size, maxSize int) *array[T] {
a := array[T]{size: size}
if maxSize < 0 || size <= maxSize {
a.values = make([]T, size)
func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
if llm.Version == 1 {
return readGGUFV1Array(llm, r)
}
return &a
}
func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
@@ -413,55 +437,45 @@ func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
return nil, err
}
switch t {
case ggufTypeUint8:
a := newArray[uint8](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt8:
a := newArray[int8](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint16:
a := newArray[uint16](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt16:
a := newArray[int16](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint32:
a := newArray[uint32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt32:
a := newArray[int32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint64:
a := newArray[uint64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt64:
a := newArray[int64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeFloat32:
a := newArray[float32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeFloat64:
a := newArray[float64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeBool:
a := newArray[bool](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeString:
a := newArray[string](int(n), llm.maxArraySize)
if llm.Version == 1 {
return readGGUFV1StringsData(llm, r, a)
}
return readGGUFStringsData(llm, r, a)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, int(n))
}
}
func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) {
for i := range a.size {
e, err := readGGUF[T](llm, r)
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
if a.values != nil {
e, err = readGGUFString(llm, r)
} else {
err = discardGGUFString(llm, r)
}
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
if err != nil {
return nil, err
}
@@ -492,8 +506,6 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
}
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
alignment := kv.Uint("general.alignment", 32)
if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
return err
}
@@ -531,15 +543,16 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
var s uint64
for _, t := range ts {
t.Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
t.Offset = s
if err := ggufWriteTensorInfo(ws, t); err != nil {
return err
}
s += t.Size()
}
var alignment int64 = 32
for _, t := range ts {
if err := ggufWriteTensor(ws, t, int64(alignment)); err != nil {
if err := ggufWriteTensor(ws, t, alignment); err != nil {
return err
}
}
@@ -616,8 +629,8 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
return err
}
for _, n := range t.Shape {
if err := binary.Write(ws, binary.LittleEndian, n); err != nil {
for i := range len(t.Shape) {
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
return err
}
}

View File

@@ -32,10 +32,9 @@ const (
fileTypeIQ1_S
fileTypeIQ4_NL
fileTypeIQ3_S
fileTypeIQ3_M
fileTypeIQ2_S
fileTypeIQ2_M
fileTypeIQ4_XS
fileTypeIQ2_M
fileTypeIQ1_M
fileTypeBF16
@@ -94,14 +93,12 @@ func ParseFileType(s string) (fileType, error) {
return fileTypeIQ4_NL, nil
case "IQ3_S":
return fileTypeIQ3_S, nil
case "IQ3_M":
return fileTypeIQ3_M, nil
case "IQ2_S":
return fileTypeIQ2_S, nil
case "IQ2_M":
return fileTypeIQ2_M, nil
case "IQ4_XS":
return fileTypeIQ4_XS, nil
case "IQ2_M":
return fileTypeIQ2_M, nil
case "IQ1_M":
return fileTypeIQ1_M, nil
case "BF16":
@@ -163,8 +160,6 @@ func (t fileType) String() string {
return "IQ4_NL"
case fileTypeIQ3_S:
return "IQ3_S"
case fileTypeIQ3_M:
return "IQ3_M"
case fileTypeIQ2_S:
return "IQ2_S"
case fileTypeIQ4_XS:

19
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/ollama/ollama
go 1.24.0
go 1.23.4
require (
github.com/containerd/console v1.0.3
@@ -11,7 +11,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.11.0
golang.org/x/sync v0.10.0
)
require (
@@ -24,7 +24,7 @@ require (
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
golang.org/x/image v0.22.0
golang.org/x/tools v0.30.0
gonum.org/v1/gonum v0.15.0
)
require (
@@ -44,7 +44,6 @@ require (
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gonum.org/v1/gonum v0.15.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect
)
@@ -70,12 +69,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.33.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.30.0
golang.org/x/term v0.29.0
golang.org/x/text v0.22.0
golang.org/x/crypto v0.31.0
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.28.0
golang.org/x/term v0.27.0
golang.org/x/text v0.21.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

30
go.sum
View File

@@ -214,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -309,8 +309,6 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -1,412 +0,0 @@
//go:build integration
package integration
import (
"bytes"
"context"
"fmt"
"math/rand"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestAPIGenerate(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: "why is the sky blue? be brief",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("pull failed %s", err)
}
tests := []struct {
name string
stream bool
}{
{
name: "stream",
stream: true,
},
{
name: "no_stream",
stream: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.GenerateResponse) error {
// Fields that must always be present
if response.Model == "" {
t.Errorf("response missing model: %#v", response)
}
if response.Done {
// Required fields for final updates:
if response.DoneReason == "" && *req.Stream {
// TODO - is the lack of done reason on non-stream a bug?
t.Errorf("final response missing done_reason: %#v", response)
}
if response.Metrics.TotalDuration == 0 {
t.Errorf("final response missing total_duration: %#v", response)
}
if response.Metrics.LoadDuration == 0 {
t.Errorf("final response missing load_duration: %#v", response)
}
if response.Metrics.PromptEvalDuration == 0 {
t.Errorf("final response missing prompt_eval_duration: %#v", response)
}
if response.Metrics.EvalCount == 0 {
t.Errorf("final response missing eval_count: %#v", response)
}
if response.Metrics.EvalDuration == 0 {
t.Errorf("final response missing eval_duration: %#v", response)
}
if len(response.Context) == 0 {
t.Errorf("final response missing context: %#v", response)
}
// Note: caching can result in no prompt eval count, so this can't be verified reliably
// if response.Metrics.PromptEvalCount == 0 {
// t.Errorf("final response missing prompt_eval_count: %#v", response)
// }
} // else incremental response, nothing to check right now...
buf.Write([]byte(response.Response))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
done := make(chan int)
var genErr error
go func() {
req.Stream = &test.stream
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
genErr = client.Generate(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil {
t.Fatalf("failed with %s request prompt %s ", req.Model, req.Prompt)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
})
}
// Validate PS while we're at it...
resp, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("list models API error: %s", err)
}
if resp == nil || len(resp.Models) == 0 {
t.Fatalf("list models API returned empty list while model should still be loaded")
}
// Find the model we just loaded and verify some attributes
found := false
for _, model := range resp.Models {
if strings.Contains(model.Name, req.Model) {
found = true
if model.Model == "" {
t.Errorf("model field omitted: %#v", model)
}
if model.Size == 0 {
t.Errorf("size omitted: %#v", model)
}
if model.Digest == "" {
t.Errorf("digest omitted: %#v", model)
}
verifyModelDetails(t, model.Details)
var nilTime time.Time
if model.ExpiresAt == nilTime {
t.Errorf("expires_at omitted: %#v", model)
}
// SizeVRAM could be zero.
}
}
if !found {
t.Errorf("unable to locate running model: %#v", resp)
}
}
func TestAPIChat(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
// Set up the test data
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: "why is the sky blue? be brief",
},
},
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("pull failed %s", err)
}
tests := []struct {
name string
stream bool
}{
{
name: "stream",
stream: true,
},
{
name: "no_stream",
stream: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.ChatResponse) error {
// Fields that must always be present
if response.Model == "" {
t.Errorf("response missing model: %#v", response)
}
if response.Done {
// Required fields for final updates:
var nilTime time.Time
if response.CreatedAt == nilTime {
t.Errorf("final response missing total_duration: %#v", response)
}
if response.DoneReason == "" {
t.Errorf("final response missing done_reason: %#v", response)
}
if response.Metrics.TotalDuration == 0 {
t.Errorf("final response missing total_duration: %#v", response)
}
if response.Metrics.LoadDuration == 0 {
t.Errorf("final response missing load_duration: %#v", response)
}
if response.Metrics.PromptEvalDuration == 0 {
t.Errorf("final response missing prompt_eval_duration: %#v", response)
}
if response.Metrics.EvalCount == 0 {
t.Errorf("final response missing eval_count: %#v", response)
}
if response.Metrics.EvalDuration == 0 {
t.Errorf("final response missing eval_duration: %#v", response)
}
if response.Metrics.PromptEvalCount == 0 {
t.Errorf("final response missing prompt_eval_count: %#v", response)
}
} // else incremental response, nothing to check right now...
buf.Write([]byte(response.Message.Content))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
done := make(chan int)
var genErr error
go func() {
req.Stream = &test.stream
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("chat never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("chat stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil {
t.Fatalf("failed with %s request prompt %v", req.Model, req.Messages)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for chat")
}
})
}
}
func TestAPIListModels(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure we have at least one model so an empty list can be considered a failure
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("unable to list models: %s", err)
}
if len(resp.Models) == 0 {
t.Fatalf("list should not be empty")
}
model := resp.Models[0]
if model.Name == "" {
t.Errorf("first model name empty: %#v", model)
}
var nilTime time.Time
if model.ModifiedAt == nilTime {
t.Errorf("first model modified_at empty: %#v", model)
}
if model.Size == 0 {
t.Errorf("first model size empty: %#v", model)
}
if model.Digest == "" {
t.Errorf("first model digest empty: %#v", model)
}
verifyModelDetails(t, model.Details)
}
func verifyModelDetails(t *testing.T, details api.ModelDetails) {
if details.Format == "" {
t.Errorf("first model details.format empty: %#v", details)
}
if details.Family == "" {
t.Errorf("first model details.family empty: %#v", details)
}
if details.ParameterSize == "" {
t.Errorf("first model details.parameter_size empty: %#v", details)
}
if details.QuantizationLevel == "" {
t.Errorf("first model details.quantization_level empty: %#v", details)
}
}
func TestAPIShowModel(t *testing.T) {
modelName := "llama3.2"
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, modelName); err != nil {
t.Fatalf("pull failed %s", err)
}
resp, err := client.Show(ctx, &api.ShowRequest{Name: modelName})
if err != nil {
t.Fatalf("unable to show model: %s", err)
}
if resp.License == "" {
t.Errorf("%s missing license: %#v", modelName, resp)
}
if resp.Modelfile == "" {
t.Errorf("%s missing modelfile: %#v", modelName, resp)
}
if resp.Parameters == "" {
t.Errorf("%s missing parameters: %#v", modelName, resp)
}
if resp.Template == "" {
t.Errorf("%s missing template: %#v", modelName, resp)
}
// llama3 omits system
verifyModelDetails(t, resp.Details)
// llama3 ommits messages
if len(resp.ModelInfo) == 0 {
t.Errorf("%s missing model_info: %#v", modelName, resp)
}
// llama3 omits projectors
var nilTime time.Time
if resp.ModifiedAt == nilTime {
t.Errorf("%s missing modified_at: %#v", modelName, resp)
}
}
func TestAPIEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbeddingRequest{
Model: "orca-mini",
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("pull failed %s", err)
}
resp, err := client.Embeddings(ctx, &req)
if err != nil {
t.Fatalf("embeddings call failed %s", err)
}
if len(resp.Embedding) == 0 {
t.Errorf("zero length embedding response")
}
}

View File

@@ -14,15 +14,15 @@ import (
"github.com/stretchr/testify/require"
)
func TestBlueSky(t *testing.T) {
func TestOrcaMiniBlueSky(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
@@ -31,7 +31,6 @@ func TestBlueSky(t *testing.T) {
}
func TestUnicode(t *testing.T) {
skipUnderMinVRAM(t, 6)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
// Set up the test data
@@ -40,7 +39,7 @@ func TestUnicode(t *testing.T) {
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
Prompt: "天空为什么是蓝色的?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
// Workaround deepseek context shifting bug
@@ -62,7 +61,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
Model: "gemma2:2b",
Prompt: "Output some smily face emoji",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
@@ -94,10 +93,10 @@ func TestUnicodeModelDir(t *testing.T) {
defer cancel()
req := api.GenerateRequest{
Model: smol,
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},

View File

@@ -21,11 +21,11 @@ func TestMultiModelConcurrency(t *testing.T) {
var (
req = [2]api.GenerateRequest{
{
Model: "llama3.2:1b",
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) {
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -67,7 +67,7 @@ func TestMultiModelConcurrency(t *testing.T) {
wg.Wait()
}
func TestIntegrationConcurrentPredict(t *testing.T) {
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
req, resp := GenerateRequests()
reqLimit := len(req)
iterLimit := 5
@@ -117,9 +117,6 @@ func TestMultiModelStress(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if maxVram < 2*format.GibiByte {
t.Skip("VRAM less than 2G, skipping model stress tests")
}
type model struct {
name string
@@ -128,8 +125,8 @@ func TestMultiModelStress(t *testing.T) {
smallModels := []model{
{
name: "llama3.2:1b",
size: 2876 * format.MebiByte,
name: "orca-mini",
size: 2992 * format.MebiByte,
},
{
name: "phi",

View File

@@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) {
Model: "llama2",
Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": 128,
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
Model: "llama2",
Prompt: "Write me a story with a ton of emojis?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": 128,

View File

@@ -12,63 +12,14 @@ import (
"github.com/stretchr/testify/require"
)
func TestVisionModels(t *testing.T) {
skipUnderMinVRAM(t, 6)
type testCase struct {
model string
}
testCases := []testCase{
{
model: "llava:7b",
},
{
model: "llama3.2-vision",
},
{
model: "gemma3",
},
}
for _, v := range testCases {
t.Run(v.model, func(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: v.model,
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
})
}
}
func TestIntegrationSplitBatch(t *testing.T) {
func TestIntegrationLlava(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: "gemma3:4b",
// Fill up a chunk of the batch so the image will partially spill over into the next one
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
Model: "llava:7b",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -88,6 +39,33 @@ func TestIntegrationSplitBatch(t *testing.T) {
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}
func TestIntegrationMllama(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
// TODO fix up once we publish the final image
Model: "x/llama3.2-vision",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
resp := "the ollamas"
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// mllama models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
}
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6

View File

@@ -17,30 +17,30 @@ var (
stream = false
req = [2]api.GenerateRequest{
{
Model: smol,
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2][]string{
{"sunlight", "scattering", "interact"},
{"sunlight"},
{"england", "english", "massachusetts", "pilgrims"},
}
)
func TestIntegrationSimple(t *testing.T) {
func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
GenerateTestHelper(ctx, t, req[0], resp[0])

View File

@@ -30,9 +30,9 @@ func TestMaxQueue(t *testing.T) {
t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
req := api.GenerateRequest{
Model: smol,
Model: "orca-mini",
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
embedCtx := ctx
var genwg sync.WaitGroup
genwg.Add(1)
go func() {
genwg.Add(1)
defer genwg.Done()
slog.Info("Starting generate request")
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
@@ -61,7 +61,7 @@ func TestMaxQueue(t *testing.T) {
}()
// Give the generate a chance to get started before we start hammering on embed requests
time.Sleep(10 * time.Millisecond)
time.Sleep(5 * time.Millisecond)
threadCount += 10 // Add a few extra to ensure we push the queue past its limit
busyCount := 0
@@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
counterMu := sync.Mutex{}
var embedwg sync.WaitGroup
for i := 0; i < threadCount; i++ {
embedwg.Add(1)
go func(i int) {
embedwg.Add(1)
defer embedwg.Done()
slog.Info("embed started", "id", i)
embedReq := api.EmbeddingRequest{

View File

@@ -1,195 +0,0 @@
//go:build integration && models
package integration
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
)
var (
started = time.Now()
chatModels = []string{
"granite3-moe:latest",
"granite-code:latest",
"nemotron-mini:latest",
"command-r:latest",
"gemma2:latest",
"gemma:latest",
"internlm2:latest",
"phi3.5:latest",
"phi3:latest",
// "phi:latest", // flaky, sometimes generates no response on first query
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
"falcon2:latest",
"minicpm-v:latest",
"mistral:latest",
"orca-mini:latest",
"llama2:latest",
"llama3.1:latest",
"llama3.2:latest",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen:latest",
"solar-pro:latest",
}
)
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline()
if !hasDeadline {
return 8 * time.Minute, 10 * time.Minute
} else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
t.Skip("too little time")
return time.Duration(0), time.Duration(0)
}
return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
}
func TestModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
// TODO - fiddle with context size
req := api.GenerateRequest{
Model: model,
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
})
}
}
func TestModelsEmbed(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
data, err := ioutil.ReadFile(filepath.Join("testdata", "embed.json"))
if err != nil {
t.Fatalf("failed to open test data file: %s", err)
}
testCase := map[string][]float64{}
err = json.Unmarshal(data, &testCase)
if err != nil {
t.Fatalf("failed to load test data: %s", err)
}
for model, expected := range testCase {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
req := api.EmbeddingRequest{
Model: model,
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
resp, err := client.Embeddings(ctx, &req)
if err != nil {
t.Fatalf("embeddings call failed %s", err)
}
if len(resp.Embedding) == 0 {
t.Errorf("zero length embedding response")
}
if len(expected) != len(resp.Embedding) {
expStr := make([]string, len(resp.Embedding))
for i, v := range resp.Embedding {
expStr[i] = fmt.Sprintf("%0.6f", v)
}
// When adding new models, use this output to populate the testdata/embed.json
fmt.Printf("expected\n%s\n", strings.Join(expStr, ", "))
t.Fatalf("expected %d, got %d", len(expected), len(resp.Embedding))
}
sim := cosineSimilarity(resp.Embedding, expected)
if sim < 0.99 {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], resp.Embedding[0:5], sim)
}
})
}
}

File diff suppressed because one or more lines are too long

View File

@@ -24,14 +24,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/stretchr/testify/require"
)
const (
smol = "llama3.2:1b"
)
func Init() {
lifecycle.InitLogging()
}
@@ -145,7 +140,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
showCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(20*time.Second),
time.Now().Add(10*time.Second),
fmt.Errorf("show for existing model %s took too long", modelName),
)
defer cancel()
@@ -162,7 +157,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
}
slog.Info("model missing", "model", modelName)
stallDuration := 60 * time.Second // This includes checksum verification, which can take a while on larger models, and slower systems
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
stallTimer := time.NewTimer(stallDuration)
fn := func(resp api.ProgressResponse) error {
// fmt.Print(".")
@@ -288,51 +283,51 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
}
// Generate a set of requests
// By default each request uses llama3.2 as the model
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
return []api.GenerateRequest{
{
Model: smol,
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "why is the color of dirt brown?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "what is the origin of independence day?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "what is the composition of air?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
@@ -346,15 +341,3 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
{"nitrogen", "oxygen", "carbon", "dioxide"},
}
}
func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err := strconv.ParseUint(s, 10, 64)
require.NoError(t, err)
// Don't hammer on small VRAM cards...
if maxVram < gb*format.GibiByte {
t.Skip("skipping with small VRAM to avoid timeouts")
}
}
}

View File

@@ -1,77 +0,0 @@
package kvcache
import (
"errors"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)
type Cache interface {
// ** used by model implementations **
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// CanResume returns true if the cache can continue with the next token at
// the given position and sequence. Assumes that the caller has already
// verified the contents of the cache.
CanResume(seq int, pos int32) bool
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}

View File

@@ -1,739 +0,0 @@
package kvcache
import (
"errors"
"fmt"
"log/slog"
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
//
// The tensors are of shape embed dim, kv heads, batch size
// The mask is of shape history size, batch size
type Causal struct {
DType ml.DType
windowSize int32
chunkSize int32
opts CausalOptions
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass **
// the active layer for Get and Put
curLayer int
// starting location for data storage for this batch
curLoc int
// size of the current batch
curBatchSize int
// mask of the cache as used by this batch
curMask ml.Tensor
// locations in the cache that are needed for this batch
curCellRange cellRange
// curSequences is the sequences corresponding to this pass's entries in the cache
curSequences []int
// curPositions is the positions corresponding to this pass's entries in the cache
curPositions []int32
// ** cache metadata **
// for each possible location in the cache, stores the position and set of sequences
// that reference the data there
cells []cacheCell
// maps from sequence to the range of locations where it is stored in the cache
cellRanges map[int]cellRange
// ** cache data storage **
shiftFn shiftFn
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
}
type cacheCell struct {
pos int32
sequences []int
}
type cellRange struct {
min int
max int
}
func NewCausalCache(shift shiftFn) *Causal {
return &Causal{
windowSize: math.MaxInt32,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
return &Causal{
windowSize: windowSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
return &Causal{
windowSize: math.MaxInt32,
chunkSize: chunkSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if c.config.CachePadding == 0 {
c.config.CachePadding = 1
}
if c.config.MaskBatchPadding == 0 {
c.config.MaskBatchPadding = 1
}
if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}
var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
cacheSize = maxSequences * capacity
} else {
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
}
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype
c.cellRanges = make(map[int]cellRange)
c.backend = backend
}
func (c *Causal) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
}
func (c *Causal) Close() {
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
if !reserve {
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
c.curLoc = 0
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
var err error
c.curMask, err = c.buildMask(ctx)
return err
}
func newRange() cellRange {
return cellRange{
min: math.MaxInt,
max: 0,
}
}
// Find the first contiguous block of at least curBatchSize
func (c *Causal) findStartLoc() (int, error) {
var start, count int
for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
count++
if count >= c.curBatchSize {
return start, nil
}
} else {
start = i + 1
count = 0
}
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
}
func (c *Causal) updateSlidingWindow() {
if c.windowSize == math.MaxInt32 {
return
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
}
newRange := newRange()
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.windowSize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
}
}
c.cellRanges[seq] = newRange
}
}
func roundDown(length, pad int) int {
return (length / pad) * pad
}
func roundUp(length, pad int) int {
return ((length + pad - 1) / pad) * pad
}
// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, i)
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
}
}
// Mask out any padding tokens we added. For padding that we added to the cache history, this
// has already been masked out because the sequence doesn't match.
for i := c.curBatchSize * length; i < len(mask); i++ {
mask[i] = float32(math.Inf(-1))
}
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
ctx.Forward(maskTensor.Copy(ctx, out))
maskTensor = out
}
return maskTensor, nil
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys {
if key == nil {
continue
}
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
func (c *Causal) defrag() {
slog.Debug("defragmenting kv cache")
// Defrag strategy:
// - Search for empty holes at the beginning of the cache,
// filling them with active data starting at the end
// - If there are contiguous elements that need to be moved,
// combine them into a single operation by holding new moves
// until we see that the next one is non-contiguous
// - Fill up the context with the maximum number of operations it
// can hold then compute that and continue with a new context
//
// We could try to optimize placement by grouping blocks from
// the same sequences together but most likely the next forward
// pass will disrupt this anyways, so the real world benefit
// seems limited as this time.
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0
for _, key := range c.keys {
if key == nil {
continue
}
layers++
}
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
src := len(c.cells) - 1
for dst := 0; dst < src; dst++ {
if len(c.cells[dst].sequences) == 0 {
for ; src > dst; src-- {
if len(c.cells[src].sequences) != 0 {
c.cells[dst] = c.cells[src]
c.cells[src] = cacheCell{}
if pendingLen > 0 {
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
pendingSrc = src
pendingLen++
break
} else {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
pendingSrc = src
pendingDst = dst
pendingLen = 1
break
}
}
}
if moves >= maxMoves {
ctx.Compute()
ctx.Close()
ctx = c.backend.NewContext()
moves = 0
}
}
if pendingLen > 0 {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
if moves > 0 {
ctx.Compute()
}
ctx.Close()
// Reset range metadata
for seq := range c.cellRanges {
seqRange := newRange()
for i, cell := range c.cells {
if slices.Contains(cell.sequences, seq) {
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[seq] = seqRange
}
}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
type CausalOptions struct {
// Enabled controls whether the causal mask is generated for a particular index in a batch
Except []int
}
// SetCausal disables causal mask generation for a particular range of indicies in
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts
if ctx != nil {
var err error
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
}
}
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
cachedSize := c.curMask.Dim(0)
key = key.View(ctx, rowSize*c.curCellRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
cachedSize,
)
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
value = value.View(ctx, elemSize*c.curCellRange.min,
cachedSize, value.Stride(1),
vHeadDim, value.Stride(2),
numKVHeads,
)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
value = value.View(ctx, rowSize*c.curCellRange.min,
vHeadDim, value.Stride(1),
numKVHeads, value.Stride(2),
cachedSize,
)
}
return key, value, c.curMask
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(0)
vHeadDim := value.Dim(0)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
if c.curBatchSize != batchSize {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
}
if _, ok := c.ctxs[c.curLayer]; !ok {
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
}
}
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
}
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
seqRange := newRange()
for i := range c.cells {
// Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
if slices.Contains(c.cells[i].sequences, dstSeq) {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
}
if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[dstSeq] = seqRange
}
func (c *Causal) CanResume(seq int, pos int32) bool {
if c.windowSize == math.MaxInt32 {
return true
}
seqRange, ok := c.cellRanges[seq]
if !ok {
return false
}
// for sliding window, check that the window of the new sequence is contained in
// the window of what we are storing
var last int32 = -1
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
last = max(last, c.cells[i].pos)
}
}
if last == -1 {
return false
}
lastWindowStart := max(0, last-c.windowSize)
posWindowStart := max(0, pos-c.windowSize)
return posWindowStart >= lastWindowStart
}
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
if c.shiftFn == nil {
return ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
seqRange := c.cellRanges[seq]
size := seqRange.max - seqRange.min + 1
offsets := make([]int32, size)
for i := range offsets {
cell := c.cells[seqRange.min+i]
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset
}
}
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
if err != nil {
return err
}
for i, key := range c.keys {
if key == nil {
continue
}
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
key = key.View(ctx, rowSize*seqRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
)
roped, err := c.shiftFn(ctx, i, key, kShift)
if err != nil {
return err
}
ctx.Forward(roped.Copy(ctx, key))
}
ctx.Compute()
return nil
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// TODO(jessegross): We should check to see if removing the middle of the sequence will
// cause the sliding window to encompass tokens that we no longer have. If so, then we
// should return an error, which will trigger the runner to evaluate the full history and
// rebuild the window. However, if we have multimodal inputs in our history, this reuse
// results in use after free, so we don't do it for now.
var offset int32
if endIndex != math.MaxInt32 {
offset = beginIndex - endIndex
}
seqRange := newRange()
for i := range c.cells {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
if c.cells[i].pos >= endIndex {
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
return errors.New("shifting cells shared by multiple sequences not supported")
}
c.cells[i].pos += offset
}
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
}
if seqRange == newRange() {
delete(c.cellRanges, seq)
return nil
}
c.cellRanges[seq] = seqRange
if endIndex != math.MaxInt32 {
err := c.shift(seq, endIndex+offset, offset)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,598 +0,0 @@
package kvcache
import (
"math"
"slices"
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
type testCase struct {
name string
in []float32
inShape []int
seqs []int
pos []int32
expected []float32
expectedShape []int
expectedMask []float32
}
func TestStore(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
inShape: []int{2, 3, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
expectedShape: []int{2, 3, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
},
{
name: "SecondBatch",
in: []float32{115, 215, 125, 225, 135, 235},
inShape: []int{2, 3, 1},
seqs: []int{0},
pos: []int32{4},
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
expectedShape: []int{2, 3, 5},
expectedMask: []float32{0, 0, 0, 0, 0},
},
}
testCache(t, backend, cache, tests)
}
func TestSWA(t *testing.T) {
backend := &testBackend{}
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
},
}
testCache(t, backend, cache, tests)
}
func TestChunkedAttention(t *testing.T) {
cache := NewChunkedAttentionCache(2, nil)
defer cache.Close()
var b testBackend
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
x := float32(math.Inf(-1))
testCache(
t, &b, cache,
[]testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{
0, x, x, x,
0, 0, x, x,
x, x, 0, x,
x, x, 0, 0,
},
},
{
name: "SecondBatch",
in: []float32{5, 6, 7},
inShape: []int{1, 1, 3},
seqs: []int{0, 0, 0},
pos: []int32{4, 5, 6},
expected: []float32{1, 2, 3, 4, 5, 6, 7},
expectedShape: []int{1, 1, 7},
expectedMask: []float32{
x, x, x, x, 0, x, x,
x, x, x, x, 0, 0, x,
x, x, x, x, x, x, 0,
},
},
{
name: "ThirdBatch",
in: []float32{8, 9},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{7, 8},
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
expectedShape: []int{1, 1, 9},
expectedMask: []float32{
x, x, x, x, x, x, 0, 0, x,
x, x, x, x, x, x, x, x, 0,
},
},
},
)
}
func TestSequences(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 1, 1},
pos: []int32{0, 1, 0, 1},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 1},
pos: []int32{2, 2},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
},
}
testCache(t, backend, cache, tests)
}
func TestRemove(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.Add(ctx, shift), nil
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 1, 1},
pos: []int32{0, 1, 0, 1},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
}
testCache(t, backend, cache, tests)
err := cache.Remove(0, 1, math.MaxInt32)
if err != nil {
panic(err)
}
tests = []testCase{
{
name: "RemoveEnd",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 1},
pos: []int32{1, 2},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
},
}
testCache(t, backend, cache, tests)
err = cache.Remove(0, 0, 1)
if err != nil {
panic(err)
}
tests = []testCase{
{
name: "RemoveMiddle",
in: []float32{7, 8},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{1, 2},
expected: []float32{7, 8, 3, 4, 4},
expectedShape: []int{1, 1, 5},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
},
}
testCache(t, backend, cache, tests)
}
func TestDefrag(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.Add(ctx, shift), nil
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
inShape: []int{1, 1, 16},
seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
expectedShape: []int{1, 1, 16},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}
testCache(t, backend, cache, tests)
err := cache.Remove(0, 2, 4)
if err != nil {
panic(err)
}
err = cache.Remove(0, 13, math.MaxInt32)
if err != nil {
panic(err)
}
tests = []testCase{
{
name: "Defrag",
in: []float32{17, 18, 19},
inShape: []int{1, 1, 3},
seqs: []int{0, 0, 0},
pos: []int32{16, 17, 18},
expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
expectedShape: []int{1, 1, 16},
expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}
testCache(t, backend, cache, tests)
}
func TestCopy(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
},
}
testCache(t, backend, cache, tests)
cache.CopyPrefix(0, 1, 2)
tests = []testCase{
{
name: "Copy",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{1, 1},
pos: []int32{3, 4},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
}
testCache(t, backend, cache, tests)
}
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
if err != nil {
panic(err)
}
cache.SetLayer(0)
tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
cache.Put(context, tensor, tensor)
out, _, mask := cache.Get(context)
context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) {
t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
}
if !slices.Equal(out.Shape(), test.expectedShape) {
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
}
if !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
}
})
}
}
func TestCanResume(t *testing.T) {
backend := &testBackend{}
windowSize := int32(4)
cache := NewSWACache(windowSize, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet
if !cache.CanResume(0, 0) {
t.Errorf("CanResume(0, 0) = false, want true (within window)")
}
if !cache.CanResume(0, 1) {
t.Errorf("CanResume(0, 1) = false, want true (within window)")
}
if !cache.CanResume(0, 2) {
t.Errorf("CanResume(0, 2) = false, want true (within window)")
}
if !cache.CanResume(0, 3) {
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
}
// shift window by adding position 4
err = cache.StartForward(context, input.Batch{
Positions: []int32{4, 5},
Sequences: []int{0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
if cache.CanResume(0, 0) {
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
}
if cache.CanResume(0, 1) {
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
}
if cache.CanResume(0, 2) {
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
}
if cache.CanResume(0, 3) {
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
}
if cache.CanResume(0, 4) {
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
}
if !cache.CanResume(0, 5) {
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
}
}
type testBackend struct {
ml.Backend
}
func (b *testBackend) NewContext() ml.Context {
return &testContext{}
}
func (b *testBackend) NewContextSize(int) ml.Context {
return &testContext{}
}
type testContext struct {
ml.Context
}
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
total := 0
if len(shape) > 0 {
total = 1
for _, s := range shape {
total *= s
}
}
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
}
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...)
}
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s)
return t, nil
}
func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
f := make([]float32, len(s))
for i := range f {
f[i] = float32(s[i])
}
out, _ := c.FromFloatSlice(f, shape...)
out.(*testTensor).dtype = ml.DTypeI32
return out, nil
}
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
s := make([]float32, 0, int((stop-start)/step))
for i := start; i < stop; i += step {
s = append(s, i)
}
out, _ := c.FromFloatSlice(s, len(s))
out.(*testTensor).dtype = dtype
return out
}
func (c *testContext) Input() ml.Context { return c }
func (c *testContext) Layer(int) ml.Context { return c }
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() error { return nil }
func (c *testContext) MaxGraphNodes() int {
return 10
}
func (c *testContext) Close() {}
type testTensor struct {
ml.Tensor
dtype ml.DType
elementSize int
data []float32
shape []int
}
func (t *testTensor) Dim(n int) int {
return t.shape[n]
}
func (t *testTensor) Stride(n int) int {
stride := t.elementSize
for i := range n {
stride *= t.shape[i]
}
return stride
}
func (t *testTensor) Shape() []int {
return t.shape
}
func (t *testTensor) DType() ml.DType {
return t.dtype
}
func (t *testTensor) Floats() []float32 {
out := make([]float32, len(t.data))
copy(out, t.data)
return out
}
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = -t.data[i]
}
return out
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
}
return out
}
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
offset /= t.elementSize
var s []int
switch len(shape) {
case 1:
s = []int{shape[0]}
case 5:
s = []int{shape[0], shape[2], shape[4]}
default:
panic("unsupported number of dimensions")
}
context := &testContext{}
view := context.Empty(t.dtype, s...).(*testTensor)
view.data = t.data[offset : offset+len(view.data)]
return view
}
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
copy(t2.(*testTensor).data, t.data)
return nil
}

View File

@@ -1,156 +0,0 @@
package kvcache
import (
"fmt"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// Encoder cache stores K and V tensors that are position independent
//
// The tensors can be of any shape and will be returned as they were stored
// The mask is currently always nil
//
// Not currently safe for multiple sequences
type EncoderCache struct {
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass **
// the active layer for Get and Put
curLayer int
// if something is stored during this pass, this
// will be the position (but there is no guarantee
// anything will be stored)
curPos int32
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// ** cache metadata **
// was something stored in the cache?
encoderCached bool
// position of the cached data
encoderPos int32
// ** cache data storage **
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
}
func NewEncoderCache() *EncoderCache {
return &EncoderCache{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if maxSequences > 1 {
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
c.backend = backend
}
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
}
func (c *EncoderCache) Close() {
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
// We work with the most recent image
if len(batch.Multimodal) > 0 {
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
}
c.curReserve = reserve
return nil
}
func (c *EncoderCache) SetLayer(layer int) {
c.curLayer = layer
}
func (c *EncoderCache) EncoderCached() bool {
return c.encoderCached
}
func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.keys[c.curLayer], c.values[c.curLayer], nil
}
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
if !c.curReserve {
c.encoderPos = c.curPos
c.encoderCached = true
}
if c.config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3)
}
if _, ok := c.ctxs[c.curLayer]; !ok {
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
}
if _, ok := c.values[c.curLayer]; !ok {
c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
}
ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer]),
value.Copy(ctx, c.values[c.curLayer]),
)
}
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("encoder cache does not support multiple sequences")
}
func (c *EncoderCache) CanResume(seq int, pos int32) bool {
return true
}
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
c.encoderCached = false
}
return nil
}

View File

@@ -1,110 +0,0 @@
package kvcache
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// Wrapper cache is a container for multiple types of caches,
// such as for the encoding and decoding portions of a model.
type WrapperCache struct {
// caches we are wrapping
caches []Cache
// cache to be used for this layer
curType int
}
func NewWrapperCache(caches ...Cache) *WrapperCache {
return &WrapperCache{
caches: caches,
}
}
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
for _, cache := range c.caches {
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
}
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
for _, cache := range c.caches {
cache.SetConfig(config)
}
}
func (c *WrapperCache) Close() {
for _, cache := range c.caches {
cache.Close()
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, batch, reserve)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {
for k := range batch.Positions {
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
}
}
return err
}
}
c.curType = 0
return nil
}
func (c *WrapperCache) SetLayer(layer int) {
for _, cache := range c.caches {
cache.SetLayer(layer)
}
}
func (c *WrapperCache) SetLayerType(layerType int) {
c.curType = layerType
}
func (c *WrapperCache) UnderlyingCache() Cache {
return c.caches[c.curType]
}
func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.caches[c.curType].Get(ctx)
}
func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.caches[c.curType].Put(ctx, key, value)
}
func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
for _, cache := range c.caches {
cache.CopyPrefix(srcSeq, dstSeq, len)
}
}
func (c *WrapperCache) CanResume(seq int, pos int32) bool {
for _, cache := range c.caches {
if !cache.CanResume(seq, pos) {
return false
}
}
return true
}
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
for _, cache := range c.caches {
err := cache.Remove(seq, beginIndex, endIndex)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,52 +1,156 @@
# `llama`
This package provides Go bindings to [llama.cpp](https://github.com/ggerganov/llama.cpp).
This package integrates the [llama.cpp](https://github.com/ggerganov/llama.cpp) library as a Go package and makes it easy to build it with tags for different CPU and GPU processors.
Supported:
- [x] CPU
- [x] avx, avx2
- [x] macOS Metal
- [x] Windows CUDA
- [x] Windows ROCm
- [x] Linux CUDA
- [x] Linux ROCm
- [x] Llava
Extra build steps are required for CUDA and ROCm on Windows since `nvcc` and `hipcc` both require using msvc as the host compiler. For these shared libraries are created:
- `ggml_cuda.dll` on Windows or `ggml_cuda.so` on Linux
- `ggml_hipblas.dll` on Windows or `ggml_hipblas.so` on Linux
> Note: it's important that memory is allocated and freed by the same compiler (e.g. entirely by code compiled with msvc or mingw). Issues from this should be rare, but there are some places where pointers are returned by the CUDA or HIP runtimes and freed elsewhere, causing a a crash. In a future change the same runtime should be used in both cases to avoid crashes.
## Building
```
go build .
```
### AVX
```shell
go build -tags avx .
```
### AVX2
```shell
# go doesn't recognize `-mfma` as a valid compiler flag
# see https://github.com/golang/go/issues/17895
go env -w "CGO_CPPFLAGS_ALLOW=-mfma|-mf16c"
go build -tags=avx,avx2 .
```
## Linux
### CUDA
Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive):
```shell
make ggml_cuda.so
go build -tags avx,cuda .
```
### ROCm
Install [ROCm](https://rocm.docs.amd.com/en/latest/).
```shell
make ggml_hipblas.so
go build -tags avx,rocm .
```
## Windows
Download [w64devkit](https://github.com/skeeto/w64devkit/releases/latest) for a simple MinGW development environment.
### CUDA
Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive) then build the cuda code:
```shell
make ggml_cuda.dll
go build -tags avx,cuda .
```
### ROCm
Install [ROCm](https://rocm.docs.amd.com/en/latest/).
```shell
make ggml_hipblas.dll
go build -tags avx,rocm .
```
## Building runners
```shell
# build all runners for this platform
make -j
```
## Vendoring
Ollama vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/llama.cpp/tree/master/ggml/src). While we generally strive to contribute changes back upstream to avoid drift, we carry a small set of patches which are applied to the tracking commit.
Ollama currently vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/ggml) through a vendoring model. While we generally strive to contribute changes back upstream to avoid drift, we cary a small set of patches which are applied to the tracking commit. A set of make targets are available to aid developers in updating to a newer tracking commit, or to work on changes.
If you update the vendoring code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory.
```shell
make -f Makefile.sync apply-patches
```
make apply-patches
```
### Updating Base Commit
**Pin to new base commit**
To change the base commit, update `FETCH_HEAD` in Makefile.sync.
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`
#### Applying patches
When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
```shell
make -f Makefile.sync apply-patches
```
make apply-patches
```
If there are conflicts, you will see an error message. Resolve the conflicts in `./vendor/`, and continue the patch series with `git am --continue` and rerun `make -f Makefile.sync apply-patches`. Repeat until all patches are successfully applied.
If you see an error message about a conflict, go into the `./vendor/` directory, and perform merge resolution using your preferred tool to the patch commit which failed. Save the file(s) and continue the patch series with `git am --continue` . If any additional patches fail, follow the same pattern until the full patch series is applied. Once finished, run a final `create-patches` and `sync` target to ensure everything is updated.
Once all patches are applied, commit the changes to the tracking repository.
```shell
make -f Makefile.sync format-patches sync
```
make create-patches sync
```
Build and test Ollama, and make any necessary changes to the Go code based on the new base commit. Submit your PR to the Ollama repo.
### Generating Patches
When working on new fixes or features that impact vendored code, use the following model. First get a clean tracking repo with all current patches applied:
```shell
make -f Makefile.sync clean apply-patches
```
make apply-patches
```
Now edit the upstream native code in the `./vendor/` directory. You do not need to commit every change in order to build, a dirty working tree in the tracking repo is OK while developing. Simply save in your editor, and run the following to refresh the vendored code with your changes, build the backend(s) and build ollama:
```
make sync
make -j 8
go build .
```
> [!IMPORTANT]
> Do **NOT** run `apply-patches` while you're iterating as that will reset the tracking repo. It will detect a dirty tree and abort, but if your tree is clean and you accidentally ran this target, use `git reflog` to recover your commit(s).
Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with
```shell
make -f Makefile.sync format-patches
```
make create-patches
```
> [!IMPORTANT]
> Once you have completed this step, it is safe to run `apply-patches` since your change is preserved in the patches.
In your `./vendor/` directory, create a branch, and cherry-pick the new commit to that branch, then submit a PR upstream to llama.cpp.

View File

@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "2016f07bd106c73699ecbaace80f55db5ed95dac";
char const *LLAMA_COMMIT = "ba1cb19cdd0d92e012e0f6e009e0620f854b6afd";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";

View File

@@ -1,4 +0,0 @@
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "@FETCH_HEAD@";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";

View File

@@ -13,7 +13,6 @@ include include/llama-*.*
include examples/
include examples/llava/
include examples/llava/clip.*
include examples/llava/clip-impl.*
include examples/llava/llava.*
include src/
include src/llama.*

View File

@@ -2,11 +2,12 @@
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "ggml.h"
#include "gguf.h"
#include "common.h"
#include "log.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include <algorithm>
@@ -48,11 +49,31 @@
#include <sys/stat.h>
#include <unistd.h>
#endif
#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#include <future>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#if defined(LLAMA_USE_CURL)
#ifdef __linux__
#include <linux/limits.h>
#elif defined(_WIN32)
# if !defined(PATH_MAX)
# define PATH_MAX MAX_PATH
# endif
#else
#include <sys/syslimits.h>
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
#endif // LLAMA_USE_CURL
using json = nlohmann::ordered_json;
//
// CPU utils
//
@@ -443,53 +464,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$0");
}
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) {
result << separator;
}
result << values[i];
}
return result.str();
}
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> parts;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
parts.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
parts.push_back(str.substr(start));
return parts;
}
std::string string_repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
std::string string_from(bool value) {
return value ? "true" : "false";
}
@@ -830,7 +804,7 @@ std::string fs_get_cache_directory() {
if (getenv("LLAMA_CACHE")) {
cache_directory = std::getenv("LLAMA_CACHE");
} else {
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
#ifdef __linux__
if (std::getenv("XDG_CACHE_HOME")) {
cache_directory = std::getenv("XDG_CACHE_HOME");
} else {
@@ -840,9 +814,7 @@ std::string fs_get_cache_directory() {
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
#elif defined(_WIN32)
cache_directory = std::getenv("LOCALAPPDATA");
#else
# error Unknown architecture
#endif
#endif // __linux__
cache_directory = ensure_trailing_slash(cache_directory);
cache_directory += "llama.cpp";
}
@@ -863,39 +835,45 @@ std::string fs_get_cache_file(const std::string & filename) {
//
// Model utils
//
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
return iparams;
llama_model * model = nullptr;
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
} else if (!params.model_url.empty()) {
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
} else {
model = llama_load_model_from_file(params.model.c_str(), mparams);
}
const llama_vocab * vocab = llama_model_get_vocab(model);
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
return iparams;
}
if (params.reranking) {
bool ok = true;
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
if (llama_token_bos(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__);
ok = false;
}
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
if (llama_token_eos(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__);
ok = false;
}
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
if (llama_token_sep(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__);
ok = false;
}
if (!ok) {
llama_model_free(model);
llama_free_model(model);
return iparams;
}
@@ -903,40 +881,39 @@ struct common_init_result common_init_from_params(common_params & params) {
auto cparams = common_context_params_to_llama(params);
llama_context * lctx = llama_init_from_model(model, cparams);
llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
llama_model_free(model);
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
return iparams;
}
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
}
if (!params.control_vectors.empty()) {
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model);
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) {
llama_free(lctx);
llama_model_free(model);
llama_free_model(model);
return iparams;
}
int err = llama_apply_adapter_cvec(
lctx,
cvec.data.data(),
cvec.data.size(),
cvec.n_embd,
params.control_vector_layer_start,
params.control_vector_layer_end);
int err = llama_control_vector_apply(lctx,
cvec.data.data(),
cvec.data.size(),
cvec.n_embd,
params.control_vector_layer_start,
params.control_vector_layer_end);
if (err) {
llama_free(lctx);
llama_model_free(model);
llama_free_model(model);
return iparams;
}
@@ -944,12 +921,12 @@ struct common_init_result common_init_from_params(common_params & params) {
// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
llama_adapter_lora_ptr lora;
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
llama_lora_adapter_ptr lora;
lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_model_free(model);
llama_free_model(model);
return iparams;
}
@@ -958,17 +935,17 @@ struct common_init_result common_init_from_params(common_params & params) {
}
if (!params.lora_init_without_apply) {
common_set_adapter_lora(lctx, params.lora_adapters);
common_lora_adapters_apply(lctx, params.lora_adapters);
}
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}
if (params.sampling.ignore_eos) {
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
for (llama_token i = 0; i < llama_n_vocab(model); i++) {
if (llama_token_is_eog(model, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias.push_back({i, -INFINITY});
}
@@ -988,12 +965,9 @@ struct common_init_result common_init_from_params(common_params & params) {
if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
llama_set_warmup(lctx, true);
std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab);
llama_token eos = llama_vocab_eos(vocab);
llama_token bos = llama_token_bos(model);
llama_token eos = llama_token_eos(model);
// some models (e.g. T5) don't have a BOS token
if (bos != LLAMA_TOKEN_NULL) {
tmp.push_back(bos);
@@ -1008,7 +982,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
if (decoder_start_token_id == -1) {
decoder_start_token_id = bos;
}
tmp.clear();
@@ -1017,10 +991,9 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
}
llama_kv_self_clear(lctx);
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
llama_set_warmup(lctx, false);
}
iparams.model.reset(model);
@@ -1029,24 +1002,11 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
const char * hf_endpoint_env = getenv("HF_ENDPOINT");
const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env;
std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) {
model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') model_endpoint += '/';
}
return model_endpoint;
}
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
llama_clear_adapter_lora(ctx);
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
llama_lora_adapter_clear(ctx);
for (auto & la : lora) {
if (la.scale != 0.0f) {
llama_set_adapter_lora(ctx, la.ptr, la.scale);
llama_lora_adapter_set(ctx, la.ptr, la.scale);
}
}
}
@@ -1057,18 +1017,16 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (!params.devices.empty()) {
mparams.devices = params.devices.data();
}
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.rpc_servers = params.rpc_servers.c_str();
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
@@ -1076,13 +1034,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.kv_overrides = params.kv_overrides.data();
}
if (params.tensor_buft_overrides.empty()) {
mparams.tensor_buft_overrides = NULL;
} else {
GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
}
return mparams;
}
@@ -1142,6 +1093,373 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
return tpp;
}
#ifdef LLAMA_USE_CURL
#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
int remaining_attempts = max_attempts;
while (remaining_attempts > 0) {
LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
return true;
}
int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
remaining_attempts--;
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
}
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
return false;
}
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
// Initialize libcurl
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__);
return false;
}
bool force_download = false;
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
// Check if hf-token or bearer-token was specified
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer ";
auth_header += hf_token.c_str();
struct curl_slist *http_headers = NULL;
http_headers = curl_slist_append(http_headers, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
}
#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata;
std::string etag;
std::string last_modified;
if (file_exists) {
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("url") && metadata.at("url").is_string()) {
auto previous_url = metadata.at("url").get<std::string>();
if (previous_url != url) {
LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
return false;
}
}
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
return false;
}
}
} else {
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
};
common_load_model_from_url_headers headers;
{
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
}
}
return n_items;
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code != 200) {
// HEAD not supported, we don't know if the file has changed
// force trigger downloading
force_download = true;
LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
}
}
bool should_download = !file_exists || force_download;
if (!should_download) {
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}
if (should_download) {
std::string path_temporary = path + ".downloadInProgress";
if (file_exists) {
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}
// Set the output file
struct FILE_deleter {
void operator()(FILE * f) const {
fclose(f);
}
};
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
// display download progress
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
}
std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
}
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}
// Causes file to be closed explicitly here before we rename it.
outfile.reset();
// Write the updated JSON metadata file.
metadata.update({
{"url", url},
{"etag", headers.etag},
{"lastModified", headers.last_modified}
});
std::ofstream(metadata_path) << metadata.dump(4);
LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
}
return true;
}
struct llama_model * common_load_model_from_url(
const std::string & model_url,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params) {
// Basic validation of the model_url
if (model_url.empty()) {
LOG_ERR("%s: invalid model_url\n", __func__);
return NULL;
}
if (!common_download_file(model_url, local_path, hf_token)) {
return NULL;
}
// check for additional GGUFs split to download
int n_split = 0;
{
struct gguf_init_params gguf_params = {
/*.no_alloc = */ true,
/*.ctx = */ NULL,
};
auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
if (!ctx_gguf) {
LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str());
return NULL;
}
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
if (key_n_split >= 0) {
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
}
gguf_free(ctx_gguf);
}
if (n_split > 1) {
char split_prefix[PATH_MAX] = {0};
char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
// Verify the first split file format
// and extract split URL and PATH prefixes
{
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
return NULL;
}
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
return NULL;
}
}
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
for (int idx = 1; idx < n_split; idx++) {
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
return common_download_file(split_url, split_path, hf_token);
}, idx));
}
// Wait for all downloads to complete
for (auto & f : futures_download) {
if (!f.get()) {
return NULL;
}
}
}
return llama_load_model_from_file(local_path.c_str(), params);
}
struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params) {
// construct hugging face model url:
//
// --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
//
// --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
//
std::string model_url = "https://huggingface.co/";
model_url += repo;
model_url += "/resolve/main/";
model_url += remote_path;
return common_load_model_from_url(model_url, local_path, hf_token, params);
}
#else
struct llama_model * common_load_model_from_url(
const std::string & /*model_url*/,
const std::string & /*local_path*/,
const std::string & /*hf_token*/,
const struct llama_model_params & /*params*/) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr;
}
struct llama_model * common_load_model_from_hf(
const std::string & /*repo*/,
const std::string & /*remote_path*/,
const std::string & /*local_path*/,
const std::string & /*hf_token*/,
const struct llama_model_params & /*params*/) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return nullptr;
}
#endif // LLAMA_USE_CURL
//
// Batch utils
//
@@ -1238,23 +1556,21 @@ std::vector<llama_token> common_tokenize(
const std::string & text,
bool add_special,
bool parse_special) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
return common_tokenize(vocab, text, add_special, parse_special);
return common_tokenize(llama_get_model(ctx), text, add_special, parse_special);
}
std::vector<llama_token> common_tokenize(
const struct llama_vocab * vocab,
const struct llama_model * model,
const std::string & text,
bool add_special,
bool parse_special) {
// upper limit for the number of tokens
int n_tokens = text.length() + 2 * add_special;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
@@ -1263,18 +1579,12 @@ std::vector<llama_token> common_tokenize(
}
std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
return common_token_to_piece(vocab, token, special);
}
std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
@@ -1284,19 +1594,13 @@ std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token
return piece;
}
std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
return common_detokenize(vocab, tokens, special);
}
std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
std::string text;
text.resize(std::max(text.capacity(), tokens.size()));
int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
if (n_chars < 0) {
text.resize(-n_chars);
n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
}
@@ -1306,6 +1610,103 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
return text;
}
//
// Chat template utils
//
std::string common_get_builtin_chat_template(const struct llama_model * model) {
static const char * template_key = "tokenizer.chat_template";
// call with NULL buffer to get the total size of the string
int32_t res = llama_model_meta_val_str(model, template_key, NULL, 0);
if (res > 0) {
std::vector<char> model_template(res + 1, 0);
llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size());
return std::string(model_template.data(), model_template.size() - 1);
}
return "";
}
bool common_chat_verify_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass) {
int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat;
for (auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
if (ptr_tmpl != nullptr) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
} else {
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true;
}
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(
fallback ? nullptr : model,
fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string common_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const struct llama_model * model,
const std::string & tmpl) {
std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
{"assistant", "Hi there"},
{"user", "How are you?"},
};
return common_chat_apply_template(model, tmpl, msgs, true);
}
//
// KV cache utils
//
@@ -1565,3 +1966,4 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result;
}

View File

@@ -4,7 +4,6 @@
#include "llama-cpp.h"
#include <set>
#include <string>
#include <vector>
#include <sstream>
@@ -25,11 +24,11 @@
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
struct common_adapter_lora_info {
struct common_lora_adapter_info {
std::string path;
float scale;
struct llama_adapter_lora * ptr;
struct llama_lora_adapter * ptr;
};
using llama_tokens = std::vector<llama_token>;
@@ -104,25 +103,6 @@ enum dimre_method {
DIMRE_METHOD_MEAN,
};
enum common_conversation_mode {
COMMON_CONVERSATION_MODE_DISABLED = 0,
COMMON_CONVERSATION_MODE_ENABLED = 1,
COMMON_CONVERSATION_MODE_AUTO = 2,
};
enum common_grammar_trigger_type {
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
};
struct common_grammar_trigger {
common_grammar_trigger_type type;
std::string value;
llama_token token = LLAMA_TOKEN_NULL;
};
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -148,7 +128,6 @@ struct common_params_sampling {
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f;// -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false;
@@ -169,10 +148,7 @@ struct common_params_sampling {
COMMON_SAMPLER_TYPE_TEMPERATURE,
};
std::string grammar; // optional BNF-like grammar to constrain sampling
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::set<llama_token> preserved_tokens;
std::string grammar; // optional BNF-like grammar to constrain sampling
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@@ -180,40 +156,28 @@ struct common_params_sampling {
std::string print() const;
};
struct common_params_model {
std::string path = ""; // model local path // NOLINT
std::string url = ""; // model url to download // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
};
struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
float p_min = 0.9f; // minimum speculative decoding probability (greedy)
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
struct common_params_model model;
std::string model = ""; // draft model for speculative decoding // NOLINT
};
struct common_params_vocoder {
struct common_params_model model;
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string speaker_file = ""; // speaker file path // NOLINT
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT
};
struct common_params {
@@ -262,12 +226,13 @@ struct common_params {
struct common_params_speculative speculative;
struct common_params_vocoder vocoder;
struct common_params_model model;
std::string model = ""; // model path // NOLINT
std::string model_alias = ""; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string prompt = ""; // NOLINT
std::string system_prompt = ""; // NOLINT
std::string prompt_file = ""; // store the external prompt file name // NOLINT
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
@@ -275,14 +240,14 @@ struct common_params {
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT
std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides;
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
std::vector<common_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
@@ -306,11 +271,11 @@ struct common_params {
bool kl_divergence = false; // compute KL divergence
bool usage = false; // print usage
bool completion = false; // print source-able completion script
bool use_color = false; // use color to distinguish generations and inputs
bool special = false; // enable special token output
bool interactive = false; // interactive mode
bool interactive_first = false; // wait for user input immediately
bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
@@ -333,15 +298,11 @@ struct common_params {
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data
bool single_turn = false; // single turn chat conversation
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see examples/llava)
struct common_params_model mmproj;
std::string mmproj = ""; // path to multimodal projector // NOLINT
std::vector<std::string> image; // path to image file(s)
// embedding
@@ -361,9 +322,7 @@ struct common_params {
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
std::vector<std::string> api_keys;
@@ -401,6 +360,8 @@ struct common_params {
int32_t i_pos = -1; // position of the passkey in the junk text
// imatrix params
std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
int32_t i_chunk = 0; // start processing from this chunk
@@ -412,16 +373,16 @@ struct common_params {
int n_pca_batch = 100;
int n_pca_iterations = 1000;
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
std::string cvector_outfile = "control_vector.gguf";
std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
bool spm_infill = false; // suffix/prefix/middle pattern for infill
std::string lora_outfile = "ggml-lora-merged-f16.gguf";
// batched-bench params
bool batched_bench_output_jsonl = false;
// common params
std::string out_file; // output filename for all example programs
};
// call once at the start of a program if it uses libcommon
@@ -440,13 +401,13 @@ bool set_process_priority(enum ggml_sched_priority prio);
//
#ifdef __GNUC__
# if defined(__MINGW32__) && !defined(__clang__)
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
# else
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
# endif
#ifdef __MINGW32__
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#else
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
#endif
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
@@ -455,14 +416,8 @@ std::string string_format(const char * fmt, ...);
std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp();
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
std::string string_repeat(const std::string & str, size_t n);
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
std::string regex_escape(const std::string & s);
template<class T>
static std::vector<T> string_split(const std::string & str, char delim) {
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
@@ -499,11 +454,6 @@ static bool string_starts_with(const std::string & str,
return str.rfind(prefix, 0) == 0;
}
static bool string_ends_with(const std::string & str,
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
@@ -531,7 +481,7 @@ struct common_init_result {
llama_model_ptr model;
llama_context_ptr context;
std::vector<llama_adapter_lora_ptr> lora;
std::vector<llama_lora_adapter_ptr> lora;
};
struct common_init_result common_init_from_params(common_params & params);
@@ -540,10 +490,20 @@ struct llama_model_params common_model_params_to_llama ( common_params
struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
// clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
struct llama_model * common_load_model_from_url(
const std::string & model_url,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
std::string get_model_endpoint();
// clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);
//
// Batch utils
@@ -581,7 +541,7 @@ std::vector<llama_token> common_tokenize(
bool parse_special = false);
std::vector<llama_token> common_tokenize(
const struct llama_vocab * vocab,
const struct llama_model * model,
const std::string & text,
bool add_special,
bool parse_special = false);
@@ -593,23 +553,48 @@ std::string common_token_to_piece(
llama_token token,
bool special = true);
std::string common_token_to_piece(
const struct llama_vocab * vocab,
llama_token token,
bool special = true);
// detokenizes a vector of tokens into a string
// should work similar to Python's `tokenizer.decode`
// optionally renders special/control tokens
std::string common_detokenize(
const struct llama_context * ctx,
llama_context * ctx,
const std::vector<llama_token> & tokens,
bool special = true);
std::string common_detokenize(
const struct llama_vocab * vocab,
const std::vector<llama_token> & tokens,
bool special = true);
//
// Chat template utils
//
// same with llama_chat_message, but uses std::string
struct common_chat_msg {
std::string role;
std::string content;
};
// Get the built-in chat template for the model. Return empty string if not present.
std::string common_get_builtin_chat_template(const struct llama_model * model);
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl);
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<common_chat_msg> & chat,
bool add_ass);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass);
// Returns an example of formatted chat
std::string common_chat_format_example(const struct llama_model * model,
const std::string & tmpl);
//
// KV cache utils

View File

@@ -1,6 +1,4 @@
#include "json-schema-to-grammar.h"
#include "common.h"
#include <algorithm>
#include <fstream>
#include <map>
@@ -13,6 +11,11 @@
using json = nlohmann::ordered_json;
template <typename Iterator>
static std::string join(Iterator begin, Iterator end, const std::string & separator);
static std::string repeat(const std::string & str, size_t n);
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
auto has_max = max_items != std::numeric_limits<int>::max();
@@ -125,8 +128,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
if (sub_len > 0) {
auto from_sub = from.substr(i + 1);
auto to_sub = to.substr(i + 1);
auto sub_zeros = string_repeat("0", sub_len);
auto sub_nines = string_repeat("9", sub_len);
auto sub_zeros = repeat("0", sub_len);
auto sub_nines = repeat("9", sub_len);
auto to_reached = false;
out << "(";
@@ -185,8 +188,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
auto max_digits = max_s.length();
for (auto digits = min_digits; digits < max_digits; digits++) {
uniform_range(min_s, string_repeat("9", digits));
min_s = "1" + string_repeat("0", digits);
uniform_range(min_s, repeat("9", digits));
min_s = "1" + repeat("0", digits);
out << " | ";
}
uniform_range(min_s, max_s);
@@ -264,7 +267,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
throw std::runtime_error("At least one of min_value or max_value must be set");
}
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
struct BuiltinRule {
std::string content;
@@ -315,6 +318,49 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator) {
std::ostringstream result;
if (begin != end) {
result << *begin;
for (Iterator it = begin + 1; it != end; ++it) {
result << separator << *it;
}
}
return result.str();
}
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
tokens.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
tokens.push_back(str.substr(start));
return tokens;
}
static std::string repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
std::smatch match;
std::string result;
@@ -343,7 +389,6 @@ static std::string format_literal(const std::string & literal) {
class SchemaConverter {
private:
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::unordered_map<std::string, std::string> _rules;
@@ -373,7 +418,7 @@ private:
for (size_t i = 0; i < alt_schemas.size(); i++) {
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
}
return string_join(rules, " | ");
return join(rules.begin(), rules.end(), " | ");
}
std::string _visit_pattern(const std::string & pattern, const std::string & name) {
@@ -436,7 +481,7 @@ private:
for (const auto & item : ret) {
results.push_back(to_rule(item));
}
return std::make_pair(string_join(results, " "), false);
return std::make_pair(join(results.begin(), results.end(), " "), false);
};
while (i < length) {
@@ -494,7 +539,7 @@ private:
}
curly_brackets += '}';
i++;
auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
int min_times = 0;
int max_times = std::numeric_limits<int>::max();
try {
@@ -809,7 +854,7 @@ public:
return;
}
std::string pointer = ref.substr(ref.find('#') + 1);
std::vector<std::string> tokens = string_split(pointer, "/");
std::vector<std::string> tokens = split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) {
@@ -860,7 +905,7 @@ public:
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
} else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@@ -974,10 +1019,10 @@ public:
void check_errors() {
if (!_errors.empty()) {
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
}
if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
}
}
@@ -990,35 +1035,11 @@ public:
}
};
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
#ifdef LLAMA_USE_LLGUIDANCE
if (!force_gbnf) {
return "%llguidance {}\nstart: %json " + schema.dump();
}
#else
(void)force_gbnf;
#endif // LLAMA_USE_LLGUIDANCE
return build_grammar([&](const common_grammar_builder & callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);
callbacks.add_schema("", copy);
});
}
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);
},
/* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
return converter.visit(schema, name == "root" ? "" : name);
},
/* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
converter.resolve_refs(schema, "");
}
};
cb(builder);
std::string json_schema_to_grammar(const json & schema) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
auto copy = schema;
converter.resolve_refs(copy, "input");
converter.visit(copy, "");
converter.check_errors();
return converter.format_grammar();
}

View File

@@ -5,17 +5,4 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
bool force_gbnf = false);
struct common_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
std::function<void(nlohmann::ordered_json &)> resolve_refs;
};
struct common_grammar_options {
bool dotall = false;
};
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);

View File

@@ -1,6 +1,5 @@
#include "log.h"
#include <chrono>
#include <condition_variable>
#include <cstdarg>
#include <cstdio>
@@ -15,6 +14,16 @@ void common_log_set_verbosity_thold(int verbosity) {
common_log_verbosity_thold = verbosity;
}
#define LOG_COL_DEFAULT "\033[0m"
#define LOG_COL_BOLD "\033[1m"
#define LOG_COL_RED "\033[31m"
#define LOG_COL_GREEN "\033[32m"
#define LOG_COL_YELLOW "\033[33m"
#define LOG_COL_BLUE "\033[34m"
#define LOG_COL_MAGENTA "\033[35m"
#define LOG_COL_CYAN "\033[36m"
#define LOG_COL_WHITE "\033[37m"
static int64_t t_us() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
}
@@ -197,7 +206,6 @@ public:
vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
}
#endif
va_end(args_copy);
}
entry.level = level;

Some files were not shown because too many files have changed in this diff Show More