Compare commits: v0.5.8-rc6...parth/samp (294 commits)
Commits in this comparison (SHA1):

4450f871db, 5ec6bb52a0, 1fd9967558, 131f0355a5, ce929984a3, 4b34930a31, 74bd09652d, fb6252d786,
c794fef2f2, 00ebda8cc4, d14ce75b95, 2d6eac9084, 3ed7ad3ab3, 6d1103048e, 0ff28758b3, d3e9ca3eda,
0fbfcf3c9c, 0c220935bd, ffbfe833da, 42a14f7f63, f8c3dbe5b5, b078dd157c, 2ddacd7516, da0e345200,
df94175a0f, 61a8825216, 021dcf089d, bf24498b1e, 95e271d98f, 364629b8d6, 108fe02165, 4561fff36e,
50b5962042, e27e4a3c1b, 088514bbd4, 2c8b484643, 8294676150, ef378ad673, 2d2247e59e, 7bf793a600,
282bfaaa95, 9679f40146, 3892c3a703, 4e320b8b90, eb2b22b042, 4ea4d2b189, 8d76fa23ef, 74b44fdf8f,
65b88c544f, a422ba39c9, d2ec22371e, 033cec232a, 543240fb5f, 4bed739259, 80c7ce381b, ccfd41c4f0,
3e102b7dad, ec46f3286c, 5e2e0b46b1, 45a13b1dec, 5c0b663969, 30d7a59ba8, 4aeb67ef4c, 3ba91634c1,
1b7433b71e, a70820daa0, 6b45b1d6b4, 85ab552028, b3af953a55, ad4e0bf3be, aee28501b5, 83f0ec8269,
c6b6938b3a, fb4664fcec, 20e3593863, 63a394068c, ab39e08eb9, 11bfa62796, f63e62e546, 65b0f329d1,
06007c0a18, a8e83a7654, 475005504e, 2c40c4d35e, e95278932b, 9d2a20a763, 2e54d72fc3, 6b32a2d549,
c5cbe4fc2a, f888912870, 9e4642e9b3, 6b0486c216, d368c039f0, 9b54267e69, 46bb0169c4, 8934324b72,
0e886595bf, c62861f4fa, 0df1800436, 631fecc6d9, 4346c2409d, 4b037a97dc, 5f74d1fd47, 4dcf80167a,
26a26998fb, 9926eae015, 8585b7b151, 7e34f4fbfa, fe776293f7, d8a5d96b98, 757668c42f, 96ec8afd09,
e093db92c4, a1cda80bcb, 4614fafae0, 4100ed7bdd, f52b2615ef, 25f9b152f9, 6da8b6a879, 0daaaef8c9,
98272fbd58, b27e8f3f10, 45df786f09, daaf42e4a4, 2dc60d4620, b5312f30e8, 26c2e0bd35, bf920883d5,
58b9ec1f6b, 7bae7fa5ce, 764e199d67, bfce55db3d, bab6f34dc0, 0682dae027, 1f6986e919, 4289c74359,
25248f4bd5, a7e63b82be, b70fc4d51e, e2252d0fc6, cae5d4d4ea, 05a01fdecb, 8fe6f69f28, 1fdb351c37,
7a01ad7614, 55ab9f371a, fefbf8f74b, b428ddd796, ba7d31240e, d25efe3954, 36dfb906bb, a6f0f908b9,
3b1ddb2b3a, 1579c4f06d, 3519dd1c6e, e41c4cbea7, ee048b76d4, af68d60a58, 21aa666a1e, ee141cc821,
55e5776c44, 854a9195f3, 96a97adf9b, e75c6126e9, cda6f5c66c, bebb6823c0, 31e472baa4, 657685e85d,
a14912858e, eed11ded30, b42aba40ed, 25885e5335, 98d44fa39d, 2099e2d267, 0c1041ad85, c245b0406f,
8b194b7520, 3e8b8a1933, 41dc280491, 53d2990d9b, e185c08ad9, 2412adf42b, be2ac1ed93, dc13813a03,
d6af13efed, a59f665235, 688925aca9, 76e903cf9d, a5272130c4, d7d7e99662, 2db96c18e7, e12af460ed,
3ad4bc8afe, 0d694793f2, e91ae3d47d, 6ecd7f64ba, 888855675e, b16367b4b2, a499390648, 4df98f3eb5,
348b3e0983, 0b7e1676eb, 314573bfe8, 4604b10306, 8c13cfa4dd, 7cfd4aee4d, 68bac1e0a6, f53f4198c3,
2192a28eed, 5d81c1a184, 5c5535c064, e5bcc51ae1, bd6a7d5e64, 14b5a9a150, ba9ec3d05e, 7c168b08c9,
3d4cc7833c, 351a85d9ea, bda4ef6c56, 1e438b237c, d721a02e7d, 778603a818, 3c874df46e, d2eb226c91,
e13e7c8d94, 78f403ff45, 5f8c03189e, 08a299e1d0, 7b5d916a9a, 33ad61b112, 716e365615, 3b4424ff98,
f9c7ead160, 5930aaeb1a, faf67db089, 0667baddc6, d006e1e09b, df2680b4b9, 010313bb63, 5296f487a8,
f05774b04c, 6600bd7d91, ed443a0393, 6945617af5, 7916f55009, d650ad398f, d223f3b697, 60830695c2,
01d9a46854, d773b7d671, 4d4463b2bd, 0e38297f87, 7e13f568dc, 58245413f4, 8cf16063a5, 3a4449e2f1,
10d59d5f90, a4f69a0191, 82658c3eec, 378d6e1e6a, afa55bc70c, 49df03da9a, 0189bdd0b7, f4711da7bd,
38117fba83, 1f766c36fb, 484a99e428, ec6121c331, b86c0a1500, 7e402ebb8c, b901a712c6, abb8dd57f8,
a400df48c0, 6ab4ba4c26, e8d4eb3e68, ae7e368f75, 31acd1ebf9, 9a4757ae66, 7814019708, b698f9a0d8,
32285a6d19, 1c198977ec, 330b6c50b0, 928911bc68, 5b446cc815, 451c1596af, 932bded12f, 070ad913ac,
8d8b9f83ae, f00d359a67, 291def6adb, cd3fbf1c49, c852b8e021, d8932c55e7
.gitattributes (vendored; 4 changed lines)

@@ -15,6 +15,10 @@ ml/backend/**/*.cu linguist-vendored
 ml/backend/**/*.cuh linguist-vendored
 ml/backend/**/*.m linguist-vendored
 ml/backend/**/*.metal linguist-vendored
+ml/backend/**/CMakeLists.txt linguist-vendored
 
+llama/build-info.cpp linguist-generated
+ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated
+
 * text=auto
 *.go text eol=lf
.github/ISSUE_TEMPLATE/10_bug_report.yml (vendored; 8 changed lines)

@@ -9,6 +9,14 @@ body:
       description: What happened? What did you expect to happen?
     validations:
       required: true
+  - type: textarea
+    id: logs
+    attributes:
+      label: Relevant log output
+      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
+      render: shell
+    validations:
+      required: false
   - type: dropdown
     id: os
     attributes:
.github/workflows/release.yaml (vendored; 62 changed lines)

@@ -111,13 +111,13 @@
       - os: windows
         arch: amd64
         preset: 'CUDA 12'
-        install: https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_551.61_windows.exe
-        cuda-version: '12.4'
+        install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
+        cuda-version: '12.8'
       - os: windows
         arch: amd64
         preset: 'ROCm 6'
-        install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe
-        rocm-version: '6.1'
+        install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
+        rocm-version: '6.2'
     runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
     environment: release
     env:

@@ -160,6 +160,10 @@
           echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
           echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
           echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
+      - if: matrix.preset == 'CPU'
+        run: |
+          echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
+          echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
       - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
         uses: actions/cache/save@v4
         with:

@@ -242,7 +246,7 @@
             dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
 
   windows-sign:
-    runs-on: windows
+    runs-on: windows-2022
     environment: release
     needs: [windows-depends, windows-build]
     steps:

@@ -303,32 +307,40 @@
     steps:
       - uses: actions/checkout@v4
      - uses: docker/setup-buildx-action@v3
+      - uses: docker/build-push-action@v6
+        with:
+          context: .
+          platforms: ${{ matrix.os }}/${{ matrix.arch }}
+          target: ${{ matrix.target }}
+          build-args: |
+            GOFLAGS=${{ env.GOFLAGS }}
+            CGO_CFLAGS=${{ env.CGO_CFLAGS }}
+            CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
+          outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
+          cache-from: type=registry,ref=ollama/ollama:latest
+          cache-to: type=inline
       - run: |
-          sudo apt-get update && sudo apt-get install pigz
-          docker buildx build --platform $PLATFORM --target ${{ matrix.target }} --build-arg GOFLAGS --build-arg CGO_CFLAGS --build-arg CGO_CXXFLAGS --output type=local,dest=dist/$PLATFORM .
-          for COMPONENTS in dist/$PLATFORM/* dist/$PLATFORM/lib/ollama/*; do
-            if [ -d "$COMPONENTS" ]; then
-              case "$COMPONENTS" in
-                */bin) echo $COMPONENTS >>dist/ollama-${PLATFORM//\//-}.tar.in ;;
-                */lib/ollama) echo $COMPONENTS >>dist/ollama-${PLATFORM//\//-}.tar.in;;
-                */lib/ollama/cuda_v11) echo $COMPONENTS >>dist/ollama-${PLATFORM//\//-}.tar.in;;
-                */lib/ollama/cuda_v12) echo $COMPONENTS >>dist/ollama-${PLATFORM//\//-}.tar.in;;
-                */lib/ollama/cuda_jetpack5) echo $COMPONENTS >>dist/ollama-${PLATFORM//\//-}-jetpack5.tar.in ;;
-                */lib/ollama/cuda_jetpack6) echo $COMPONENTS >>dist/ollama-${PLATFORM//\//-}-jetpack6.tar.in ;;
-                */lib/ollama/rocm) echo $COMPONENTS >>dist/ollama-${PLATFORM//\//-}-rocm.tar.in ;;
+          for COMPONENT in bin/* lib/ollama/*; do
+            case "$COMPONENT" in
+              bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
+              lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
+              lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
             esac
-            fi
           done
-          for ARCHIVE in dist/*.tar.in; do tar c -T $ARCHIVE --strip-components 3 | pigz -9cv >${ARCHIVE//.*/}.tgz; done
-        env:
-          PLATFORM: ${{ matrix.os }}/${{ matrix.arch }}
+        working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
+      - run: |
+          for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
+            tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
+          done
       - uses: actions/upload-artifact@v4
         with:
           name: dist-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.target }}
           path: |
-            dist/*.tgz
+            *.tgz
 
 # Build each Docker variant (OS, arch, and flavor) separately. Using QEMU is unreliable and slower.
 docker-build-push:

@@ -362,7 +374,7 @@
       GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
     steps:
       - uses: actions/checkout@v4
-      - uses: docker/setup-buildx-action@v2
+      - uses: docker/setup-buildx-action@v3
       - uses: docker/login-action@v3
         with:
           username: ${{ vars.DOCKER_USER }}
.github/workflows/test.yaml (vendored; 90 changed lines)

@@ -78,10 +78,10 @@
       include:
         - preset: CPU
         - preset: CUDA
-          install: https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe
-          flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
+          install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
+          flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
         - preset: ROCm
-          install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe
+          install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
           flags: '-DAMDGPU_TARGETS=gfx1010'
     runs-on: windows
     steps:

@@ -102,7 +102,7 @@
           $ErrorActionPreference = "Stop"
           if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
             Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
-            Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.8", "nvcc_11.8", "cublas_11.8", "cublas_dev_11.8")) -NoNewWindow -Wait
+            Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
           }
 
           $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path

@@ -140,6 +140,13 @@
     env:
       CMAKE_GENERATOR: Ninja
 
+  go_mod_tidy:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: check that 'go mod tidy' is clean
+        run: go mod tidy --diff || (echo "Please run 'go mod tidy'." && exit 1)
+
   test:
     strategy:
       matrix:

@@ -147,15 +154,82 @@
     runs-on: ${{ matrix.os }}
     env:
       CGO_ENABLED: '1'
+      GOEXPERIMENT: 'synctest'
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - name: checkout
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
+
+      - name: cache restore
+        uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
         with:
+          # Note: unlike the other setups, this is only grabbing the mod download
+          # cache, rather than the whole mod directory, as the download cache
+          # contains zips that can be unpacked in parallel faster than they can be
+          # fetched and extracted by tar
+          path: |
+            ~/.cache/go-build
+            ~/go/pkg/mod/cache
+            ~\AppData\Local\go-build
+          # NOTE: The -3- here should be incremented when the scheme of data to be
+          # cached changes (e.g. path above changes).
+          key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
+          restore-keys: |
+            ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}
+            ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-
+
+      - name: Setup Go
+        uses: actions/setup-go@v5
+        with:
+          # The caching strategy of setup-go is less than ideal, and wastes
+          # time by not saving artifacts due to small failures like the linter
+          # complaining, etc. This means subsequent have to rebuild their world
+          # again until all checks pass. For instance, if you mispell a word,
+          # you're punished until you fix it. This is more hostile than
+          # helpful.
+          cache: false
           go-version-file: go.mod
+
+      # It is tempting to run this in a platform independent way, but the past
+      # shows this codebase will see introductions of platform specific code
+      # generation, and so we need to check this per platform to ensure we
+      # don't abuse go generate on specific platforms.
+      - name: check that 'go generate' is clean
+        if: always()
+        run: |
+          go generate ./...
+          git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1)
+
+      - name: go test
+        if: always()
+        run: go test -count=1 -benchtime=1x ./...
+
+      # TODO(bmizerany): replace this heavy tool with just the
+      # tools/checks/binaries we want and then make them all run in parallel
+      # across jobs, not on a single tiny vm on Github Actions.
       - uses: golangci/golangci-lint-action@v6
         with:
           args: --timeout 10m0s -v
-      - run: go test ./...
+
+      - name: cache save
+        # Always save the cache, even if the job fails. The artifacts produced
+        # during the building of test binaries are not all for naught. They can
+        # be used to speed up subsequent runs.
+        if: always()
+        uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
+        with:
+          # Note: unlike the other setups, this is only grabbing the mod download
+          # cache, rather than the whole mod directory, as the download cache
+          # contains zips that can be unpacked in parallel faster than they can be
+          # fetched and extracted by tar
+          path: |
+            ~/.cache/go-build
+            ~/go/pkg/mod/cache
+            ~\AppData\Local\go-build
+          # NOTE: The -3- here should be incremented when the scheme of data to be
+          # cached changes (e.g. path above changes).
+          key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
 
   patches:
     runs-on: ubuntu-latest

@@ -163,5 +237,5 @@
       - uses: actions/checkout@v4
       - name: Verify patches apply cleanly and do not change files
         run: |
-          make -f Makefile.sync clean checkout sync
+          make -f Makefile.sync clean sync
           git diff --compact-summary --exit-code
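For orientation, the checks added above (`go_mod_tidy`, the `go generate` cleanliness check, and the new `go test` invocation) can be approximated on a local checkout; a hedged sketch assembled from the commands in the hunks (assumes a Go toolchain recent enough for `go mod tidy --diff` and the `synctest` experiment):

```shell
# Rough local equivalent of the new CI checks (commands taken from the hunks above).
go mod tidy --diff || echo "Please run 'go mod tidy'."
go generate ./...
git diff --name-only --exit-code || echo "Please run 'go generate ./...'."
GOEXPERIMENT=synctest go test -count=1 -benchtime=1x ./...
```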
.gitignore (vendored; 2 changed lines)

@@ -5,7 +5,6 @@
 .swp
 dist
 build
-ollama
 .cache
 *.exe
 .idea

@@ -14,3 +13,4 @@ test_data
 __debug_bin*
 llama/build
 llama/vendor
+/ollama
.golangci.yaml (golangci-lint configuration)

@@ -6,8 +6,6 @@ linters:
     - bidichk
     - bodyclose
     - containedctx
-    - contextcheck
-    - errcheck
     - gocheckcompilerdirectives
     - gofmt
     - gofumpt

@@ -23,10 +21,11 @@ linters:
     - staticcheck
     - tenv
     - unconvert
-    - unused
-    - usestdlibvars
     - wastedassign
     - whitespace
+  disable:
+    - usestdlibvars
+    - errcheck
 linters-settings:
   staticcheck:
     checks:

@@ -39,5 +38,4 @@ severity:
         - gofmt
         - goimports
         - intrange
-        - usestdlibvars
       severity: info
CMakeLists.txt

@@ -23,8 +23,9 @@ set(GGML_SCHED_MAX_COPIES 4)
 set(GGML_LLAMAFILE ON)
 set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
 set(GGML_CUDA_GRAPHS ON)
+set(GGML_CUDA_FA ON)
 
-if((NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
+if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
     OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
     set(GGML_CPU_ALL_VARIANTS ON)
 endif()

@@ -85,6 +86,11 @@ if(CMAKE_CUDA_COMPILER)
     )
 endif()
 
+set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
+    CACHE STRING
+    "Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
+)
+
 check_language(HIP)
 if(CMAKE_HIP_COMPILER)
     set(HIP_PLATFORM "amd")

@@ -92,15 +98,24 @@
     find_package(hip REQUIRED)
     if(NOT AMDGPU_TARGETS)
         list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
+    elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
+        list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
     endif()
 
     if(AMDGPU_TARGETS)
         add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
+
+        if (WIN32)
+            target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
+        endif()
+
+        target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
+
         set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
         install(TARGETS ggml-hip
             RUNTIME_DEPENDENCIES
                 DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
-                PRE_INCLUDE_REGEXES amdhip64 hipblas rocblas amd_comgr hsa_runtime64 rocprofiler-register drm_amdgpu drm numa
+                PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
                 PRE_EXCLUDE_REGEXES ".*"
                 POST_EXCLUDE_REGEXES "system32"
             RUNTIME DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP
CMakePresets.json

@@ -21,14 +21,14 @@
       "name": "CUDA 11",
       "inherits": [ "CUDA" ],
       "cacheVariables": {
-        "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;62;70;72;75;80;86"
+        "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86"
       }
     },
     {
       "name": "CUDA 12",
       "inherits": [ "CUDA" ],
       "cacheVariables": {
-        "CMAKE_CUDA_ARCHITECTURES": "60;61;62;70;72;75;80;86;87;89;90;90a"
+        "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120"
       }
     },
     {

@@ -56,7 +56,7 @@
       "name": "ROCm 6",
       "inherits": [ "ROCm" ],
       "cacheVariables": {
-        "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102"
+        "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
       }
     }
   ],
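The new `WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX` cache variable in CMakeLists.txt is documented as overridable to force-build the excluded targets. A minimal sketch of doing that at configure time with the 'ROCm 6' preset from CMakePresets.json; the empty-string override is an assumption for illustration, not something shown in the diff:

```shell
# Hedged sketch: clearing the exclusion regex keeps the gfx906/gfx908/gfx90a xnack
# variants in AMDGPU_TARGETS on Windows (variable and preset names from the hunks above).
cmake --preset 'ROCm 6' -DWINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX=""
cmake --build --parallel --preset 'ROCm 6'
```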
CONTRIBUTING.md

@@ -6,8 +6,6 @@ Thank you for your interest in contributing to Ollama! Here are a few guidelines
 
 See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.
 
-## Pull requests
-
 ### Ideal issues
 
 * [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.

@@ -26,11 +24,64 @@ See the [development documentation](./docs/development.md) for instructions on h
 * Changes that add significant friction to the user experience
 * Changes that create a large future maintenance burden for maintainers and contributors
 
-### Best practices
+## Proposing a (non-trivial) change
 
-* Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`) . In the description, leave a short 2-3 sentences that explain more about the change and its impact.
-* Tests: please add test coverage to changes where possible.
-* Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
+> By "non-trivial", we mean a change that is not a bug fix or small
+> documentation update. If you are unsure, please ask us on our [Discord
+> server](https://discord.gg/ollama).
+
+Before opening a non-trivial Pull Request, please open an issue to discuss the change and
+get feedback from the maintainers. This helps us understand the context of the
+change and how it fits into Ollama's roadmap and prevents us from duplicating
+work or you from spending time on a change that we may not be able to accept.
+
+Tips for proposals:
+
+* Explain the problem you are trying to solve, not what you are trying to do.
+* Explain why the change is important.
+* Explain how the change will be used.
+* Explain how the change will be tested.
+
+Additionally, for bonus points: Provide draft documentation you would expect to
+see if the change were accepted.
+
+## Pull requests
+
+**Commit messages**
+
+The title should look like:
+
+    <package>: <short description>
+
+The package is the most affected Go package. If the change does not affect Go
+code, then use the directory name instead. Changes to a single well-known
+file in the root directory may use the file name.
+
+The short description should start with a lowercase letter and be a
+continuation of the sentence:
+
+    "This changes Ollama to..."
+
+Examples:
+
+    llm/backend/mlx: support the llama architecture
+    CONTRIBUTING: provide clairity on good commit messages, and bad
+
+Bad Examples:
+
+    feat: add more emoji
+    fix: was not using famous web framework
+    chore: generify code
+
+**Tests**
+
+Please include tests. Strive to test behavior, not implementation.
+
+**New dependencies**
+
+Dependencies should be added sparingly. If you are adding a new dependency,
+please explain why it is necessary and what other ways you attempted that
+did not work without it.
+
 ## Need help?
 
Dockerfile (43 changed lines)

@@ -2,22 +2,24 @@
 
 ARG FLAVOR=${TARGETARCH}
 
-ARG ROCMVERSION=6.1.2
+ARG ROCMVERSION=6.3.3
 ARG JETPACK5VERSION=r35.4.1
-ARG JETPACK6VERSION=r36.2.0
+ARG JETPACK6VERSION=r36.4.0
 ARG CMAKEVERSION=3.31.2
 
-FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCMVERSION}-complete AS base-amd64
-RUN sed -i -e 's/mirror.centos.org/vault.centos.org/g' -e 's/^#.*baseurl=http/baseurl=http/g' -e 's/^mirrorlist=http/#mirrorlist=http/g' /etc/yum.repos.d/*.repo \
-    && yum install -y yum-utils devtoolset-10-gcc devtoolset-10-gcc-c++ \
-    && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo \
-    && curl -s -L https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz | tar -Jx -C /usr/local/bin --strip-components 1
-ENV PATH=/opt/rh/devtoolset-10/root/usr/bin:/opt/rh/devtoolset-11/root/usr/bin:$PATH
+# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
+FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
+RUN yum install -y yum-utils \
+    && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
+    && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
+    && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
+    && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
+ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
 
-FROM --platform=linux/arm64 rockylinux:8 AS base-arm64
+FROM --platform=linux/arm64 almalinux:8 AS base-arm64
 # install epel-release for ccache
 RUN yum install -y yum-utils epel-release \
-    && yum install -y clang ccache \
+    && dnf install -y clang ccache \
     && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
 ENV CC=clang CXX=clang++
 

@@ -29,9 +31,8 @@ COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
 ENV LDFLAGS=-s
 
 FROM base AS cpu
-# amd64 uses gcc which requires devtoolset-11 for AVX extensions while arm64 uses clang
-RUN if [ "$(uname -m)" = "x86_64" ]; then yum install -y devtoolset-11-gcc devtoolset-11-gcc-c++; fi
-ENV PATH=/opt/rh/devtoolset-11/root/usr/bin:$PATH
+RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
+ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'CPU' \
     && cmake --build --parallel --preset 'CPU' \

@@ -39,7 +40,7 @@ RUN --mount=type=cache,target=/root/.ccache \
 
 FROM base AS cuda-11
 ARG CUDA11VERSION=11.3
-RUN yum install -y cuda-toolkit-${CUDA11VERSION//./-}
+RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
 ENV PATH=/usr/local/cuda-11/bin:$PATH
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'CUDA 11' \

@@ -47,8 +48,8 @@ RUN --mount=type=cache,target=/root/.ccache \
     && cmake --install build --component CUDA --strip --parallel 8
 
 FROM base AS cuda-12
-ARG CUDA12VERSION=12.4
-RUN yum install -y cuda-toolkit-${CUDA12VERSION//./-}
+ARG CUDA12VERSION=12.8
+RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
 ENV PATH=/usr/local/cuda-12/bin:$PATH
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'CUDA 12' \

@@ -56,6 +57,7 @@ RUN --mount=type=cache,target=/root/.ccache \
     && cmake --install build --component CUDA --strip --parallel 8
 
 FROM base AS rocm-6
+ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'ROCm 6' \
     && cmake --build --parallel --preset 'ROCm 6' \

@@ -84,10 +86,11 @@ RUN --mount=type=cache,target=/root/.ccache \
     && cmake --install build --component CUDA --strip --parallel 8
 
 FROM base AS build
-ARG GOVERSION=1.23.4
-RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
-ENV PATH=/usr/local/go/bin:$PATH
 WORKDIR /go/src/github.com/ollama/ollama
+COPY go.mod go.sum .
+RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
+ENV PATH=/usr/local/go/bin:$PATH
+RUN go mod download
 COPY . .
 ARG GOFLAGS="'-ldflags=-w -s'"
 ENV CGO_ENABLED=1

@@ -104,7 +107,7 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
 COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
 COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6
 
-FROM --platform=linux/arm64 scratch AS rocm
+FROM scratch AS rocm
 COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
 
 FROM ${FLAVOR} AS archive
Makefile.sync

@@ -1,6 +1,6 @@
 UPSTREAM=https://github.com/ggerganov/llama.cpp.git
 WORKDIR=llama/vendor
-FETCH_HEAD=46e3556e01b824e52395fb050b29804b6cff2a7c
+FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c
 
 .PHONY: help
 help:

@@ -15,7 +15,11 @@ help:
     @echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
 
 .PHONY: sync
-sync: llama/llama.cpp ml/backend/ggml/ggml apply-patches
+sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml apply-patches
+
+.PHONY: llama/build-info.cpp
+llama/build-info.cpp: llama/build-info.cpp.in
+    sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
 
 .PHONY: llama/llama.cpp
 llama/llama.cpp: llama/vendor/ apply-patches
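For illustration, here is what the new `llama/build-info.cpp` rule expands to when `make -f Makefile.sync sync` runs; a hedged sketch using the FETCH_HEAD value pinned above (the contents of the `.in` template are not part of this diff):

```shell
# Hedged sketch: the rule stamps the pinned llama.cpp commit into the generated
# source via sed (command and FETCH_HEAD value taken from the hunks above).
sed -e 's|@FETCH_HEAD@|d7cfe1ffe0f435d0048a6058d529daf76e072d9c|' \
    llama/build-info.cpp.in > llama/build-info.cpp
```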
README.md (94 changed lines)

@@ -1,5 +1,5 @@
 <div align="center">
-  <a href="https://ollama.com" />
+  <a href="https://ollama.com">
   <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
   </a>
 </div>

@@ -18,7 +18,7 @@ Get up and running with large language models.
 
 ### Linux
 
-```
+```shell
 curl -fsSL https://ollama.com/install.sh | sh
 ```
 

@@ -42,7 +42,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
 
 To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
 
-```
+```shell
 ollama run llama3.2
 ```
 

@@ -54,6 +54,13 @@ Here are some example models that can be downloaded:
 
 | Model              | Parameters | Size  | Download                         |
 | ------------------ | ---------- | ----- | -------------------------------- |
+| Gemma 3            | 1B         | 815MB | `ollama run gemma3:1b`           |
+| Gemma 3            | 4B         | 3.3GB | `ollama run gemma3`              |
+| Gemma 3            | 12B        | 8.1GB | `ollama run gemma3:12b`          |
+| Gemma 3            | 27B        | 17GB  | `ollama run gemma3:27b`          |
+| QwQ                | 32B        | 20GB  | `ollama run qwq`                 |
+| DeepSeek-R1        | 7B         | 4.7GB | `ollama run deepseek-r1`         |
+| DeepSeek-R1        | 671B       | 404GB | `ollama run deepseek-r1:671b`    |
 | Llama 3.3          | 70B        | 43GB  | `ollama run llama3.3`            |
 | Llama 3.2          | 3B         | 2.0GB | `ollama run llama3.2`            |
 | Llama 3.2          | 1B         | 1.3GB | `ollama run llama3.2:1b`         |

@@ -62,10 +69,7 @@ Here are some example models that can be downloaded:
 | Llama 3.1          | 8B         | 4.7GB | `ollama run llama3.1`            |
 | Llama 3.1          | 405B       | 231GB | `ollama run llama3.1:405b`       |
 | Phi 4              | 14B        | 9.1GB | `ollama run phi4`                |
-| Phi 3 Mini         | 3.8B       | 2.3GB | `ollama run phi3`                |
-| Gemma 2            | 2B         | 1.6GB | `ollama run gemma2:2b`           |
-| Gemma 2            | 9B         | 5.5GB | `ollama run gemma2`              |
-| Gemma 2            | 27B        | 16GB  | `ollama run gemma2:27b`          |
+| Phi 4 Mini         | 3.8B       | 2.5GB | `ollama run phi4-mini`           |
 | Mistral            | 7B         | 4.1GB | `ollama run mistral`             |
 | Moondream 2        | 1.4B       | 829MB | `ollama run moondream`           |
 | Neural Chat        | 7B         | 4.1GB | `ollama run neural-chat`         |

@@ -73,7 +77,7 @@ Here are some example models that can be downloaded:
 | Code Llama         | 7B         | 3.8GB | `ollama run codellama`           |
 | Llama 2 Uncensored | 7B         | 3.8GB | `ollama run llama2-uncensored`   |
 | LLaVA              | 7B         | 4.5GB | `ollama run llava`               |
-| Solar              | 10.7B      | 6.1GB | `ollama run solar`               |
+| Granite-3.2        | 8B         | 4.9GB | `ollama run granite3.2`          |
 
 > [!NOTE]
 > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

@@ -92,13 +96,13 @@ Ollama supports importing GGUF models in the Modelfile:
 
 2. Create the model in Ollama
 
-   ```
+   ```shell
    ollama create example -f Modelfile
    ```
 
 3. Run the model
 
-   ```
+   ```shell
    ollama run example
    ```
 

@@ -110,7 +114,7 @@ See the [guide](docs/import.md) on importing models for more information.
 
 Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3.2` model:
 
-```
+```shell
 ollama pull llama3.2
 ```
 

@@ -145,13 +149,13 @@ For more information on working with a Modelfile, see the [Modelfile](docs/model
 
 `ollama create` is used to create a model from a Modelfile.
 
-```
+```shell
 ollama create mymodel -f ./Modelfile
 ```
 
 ### Pull a model
 
-```
+```shell
 ollama pull llama3.2
 ```
 

@@ -159,13 +163,13 @@ ollama pull llama3.2
 
 ### Remove a model
 
-```
+```shell
 ollama rm llama3.2
 ```
 
 ### Copy a model
 
-```
+```shell
 ollama cp llama3.2 my-model
 ```
 

@@ -184,37 +188,39 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
 
 ```
 ollama run llava "What's in this image? /Users/jmorgan/Desktop/smile.png"
-The image features a yellow smiley face, which is likely the central focus of the picture.
 ```
 
+> **Output**: The image features a yellow smiley face, which is likely the central focus of the picture.
+
 ### Pass the prompt as an argument
 
+```shell
+ollama run llama3.2 "Summarize this file: $(cat README.md)"
 ```
-$ ollama run llama3.2 "Summarize this file: $(cat README.md)"
-Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
-```
+
+> **Output**: Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
 
 ### Show model information
 
-```
+```shell
 ollama show llama3.2
 ```
 
 ### List models on your computer
 
-```
+```shell
 ollama list
 ```
 
 ### List which models are currently loaded
 
-```
+```shell
 ollama ps
 ```
 
 ### Stop a model which is currently running
 
-```
+```shell
 ollama stop llama3.2
 ```
 

@@ -230,13 +236,13 @@ See the [developer guide](https://github.com/ollama/ollama/blob/main/docs/develo
 
 Next, start the server:
 
-```
+```shell
 ./ollama serve
 ```
 
 Finally, in a separate shell, run a model:
 
-```
+```shell
 ./ollama run llama3.2
 ```
 

@@ -246,7 +252,7 @@ Ollama has a REST API for running and managing models.
 
 ### Generate a response
 
-```
+```shell
 curl http://localhost:11434/api/generate -d '{
   "model": "llama3.2",
   "prompt":"Why is the sky blue?"

@@ -255,7 +261,7 @@ curl http://localhost:11434/api/generate -d '{
 
 ### Chat with a model
 
-```
+```shell
 curl http://localhost:11434/api/chat -d '{
   "model": "llama3.2",
   "messages": [

@@ -271,6 +277,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Web & Desktop
 
 - [Open WebUI](https://github.com/open-webui/open-webui)
+- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
 - [Hollama](https://github.com/fmaclen/hollama)
 - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)

@@ -353,6 +360,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
 - [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
+- [chat-ollama](https://github.com/annilq/chat-ollama) (a React Native client for Ollama)
 - [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
 - [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
 - [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)

@@ -369,9 +377,23 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
 - [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
 - [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
+- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
 - [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
 - [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
 - [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
+- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
+- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
+- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
+- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
+- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
+- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
+- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
+- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
+- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
+- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
+- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
+- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
+- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
 
 ### Cloud
 

@@ -415,6 +437,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 
 ### Apple Vision Pro
 
+- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
 - [Enchanted](https://github.com/AugustDev/enchanted)
 
 ### Database

@@ -429,9 +452,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
 
 - [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
 - [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama)
+- [Homebrew](https://formulae.brew.sh/formula/ollama)
 - [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama)
 - [Guix channel](https://codeberg.org/tusharhero/ollama-guix)
-- [Nix package](https://search.nixos.org/packages?channel=24.05&show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
+- [Nix package](https://search.nixos.org/packages?show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
 - [Flox](https://flox.dev/blog/ollama-part-one)
 
 ### Libraries

@@ -485,13 +509,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
 - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
 - [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
+- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
+- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
+- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
+- [Ollama for D](https://github.com/kassane/ollama-d)
 
 ### Mobile
 
+- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
 - [Enchanted](https://github.com/AugustDev/enchanted)
 - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
 - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
+- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
+- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
 
 ### Extensions & Plugins
 

@@ -535,13 +566,18 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
 - [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
 - [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
|
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
|
||||||
|
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
|
||||||
|
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
|
||||||
|
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
|
||||||
|
|
||||||
### Supported backends
|
### Supported backends
|
||||||
|
|
||||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||||
|
|
||||||
### Observability
|
### Observability
|
||||||
|
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
|
||||||
|
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
||||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
||||||
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
|
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
|
||||||
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
|
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
|
||||||
|
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
|
||||||
@@ -10,7 +10,7 @@
// repository].
//
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
-// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
+// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/api/examples
package api

import (
@@ -132,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
const maxBufferSize = 512 * format.KiloByte

func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
-    var buf *bytes.Buffer
+    var buf io.Reader
    if data != nil {
        bts, err := json.Marshal(data)
        if err != nil {
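The `stream` helper above is what the public client methods build on. As a minimal, hedged sketch of how that streaming surface looks from the caller's side (assuming a locally running Ollama server and an already pulled model; the model name `llama3.2` is a placeholder), the callback is simply invoked once per streamed chunk:

```go
package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/api"
)

func main() {
    // ClientFromEnvironment honors OLLAMA_HOST and falls back to the default address.
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    req := &api.GenerateRequest{
        Model:  "llama3.2", // placeholder model name
        Prompt: "Why is the sky blue?",
    }

    // The callback runs once per streamed chunk until the final chunk reports Done.
    err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
        fmt.Print(resp.Response)
        return nil
    })
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println()
}
```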
|
|||||||
@@ -1,6 +1,13 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,3 +50,206 @@ func TestClientFromEnvironment(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testError represents an internal error type with status code and message
|
||||||
|
// this is used since the error response from the server is not a standard error struct
|
||||||
|
type testError struct {
|
||||||
|
message string
|
||||||
|
statusCode int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e testError) Error() string {
|
||||||
|
return e.message
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientStream(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
responses []any
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "immediate error response",
|
||||||
|
responses: []any{
|
||||||
|
testError{
|
||||||
|
message: "test error message",
|
||||||
|
statusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "test error message",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error after successful chunks, ok response",
|
||||||
|
responses: []any{
|
||||||
|
ChatResponse{Message: Message{Content: "partial response 1"}},
|
||||||
|
ChatResponse{Message: Message{Content: "partial response 2"}},
|
||||||
|
testError{
|
||||||
|
message: "mid-stream error",
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "mid-stream error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "successful stream completion",
|
||||||
|
responses: []any{
|
||||||
|
ChatResponse{Message: Message{Content: "chunk 1"}},
|
||||||
|
ChatResponse{Message: Message{Content: "chunk 2"}},
|
||||||
|
ChatResponse{
|
||||||
|
Message: Message{Content: "final chunk"},
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected http.Flusher")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||||
|
|
||||||
|
for _, resp := range tc.responses {
|
||||||
|
if errResp, ok := resp.(testError); ok {
|
||||||
|
w.WriteHeader(errResp.statusCode)
|
||||||
|
err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": errResp.message,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to encode error response:", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
|
t.Fatalf("failed to encode response: %v", err)
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
|
||||||
|
|
||||||
|
var receivedChunks []ChatResponse
|
||||||
|
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
|
||||||
|
var resp ChatResponse
|
||||||
|
if err := json.Unmarshal(chunk, &resp); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal chunk: %w", err)
|
||||||
|
}
|
||||||
|
receivedChunks = append(receivedChunks, resp)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if tc.wantErr != "" {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error but got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), tc.wantErr) {
|
||||||
|
t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientDo(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
response any
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "immediate error response",
|
||||||
|
response: testError{
|
||||||
|
message: "test error message",
|
||||||
|
statusCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
wantErr: "test error message",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server error response",
|
||||||
|
response: testError{
|
||||||
|
message: "internal error",
|
||||||
|
statusCode: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
wantErr: "internal error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "successful response",
|
||||||
|
response: struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
}{
|
||||||
|
ID: "msg_123",
|
||||||
|
Success: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if errResp, ok := tc.response.(testError); ok {
|
||||||
|
w.WriteHeader(errResp.statusCode)
|
||||||
|
err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": errResp.message,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to encode error response:", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(tc.response); err != nil {
|
||||||
|
t.Fatalf("failed to encode response: %v", err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
}
|
||||||
|
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
|
||||||
|
|
||||||
|
if tc.wantErr != "" {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("got nil, want error %q", tc.wantErr)
|
||||||
|
}
|
||||||
|
if err.Error() != tc.wantErr {
|
||||||
|
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("got error %q, want nil", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedResp, ok := tc.response.(struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
}); ok {
|
||||||
|
if resp.ID != expectedResp.ID {
|
||||||
|
t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
|
||||||
|
}
|
||||||
|
if resp.Success != expectedResp.Success {
|
||||||
|
t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,9 +2,10 @@

Run the examples in this directory with:

-```
+```shell
go run example_name/main.go
```

## Chat - Chat with a model
- [chat/main.go](chat/main.go)
18 api/types.go
@@ -10,6 +10,8 @@ import (
    "strconv"
    "strings"
    "time"
+
+    "github.com/ollama/ollama/envconfig"
)

// StatusError is an error with an HTTP status code and message.
@@ -347,6 +349,7 @@ type ShowResponse struct {
    Messages      []Message      `json:"messages,omitempty"`
    ModelInfo     map[string]any `json:"model_info,omitempty"`
    ProjectorInfo map[string]any `json:"projector_info,omitempty"`
+    Tensors       []Tensor       `json:"tensors,omitempty"`
    ModifiedAt    time.Time      `json:"modified_at,omitempty"`
}

@@ -359,9 +362,9 @@ type CopyRequest struct {
// PullRequest is the request passed to [Client.Pull].
type PullRequest struct {
    Model    string `json:"model"`
-    Insecure bool   `json:"insecure,omitempty"`
-    Username string `json:"username"`
-    Password string `json:"password"`
+    Insecure bool   `json:"insecure,omitempty"` // Deprecated: ignored
+    Username string `json:"username"`           // Deprecated: ignored
+    Password string `json:"password"`           // Deprecated: ignored
    Stream   *bool  `json:"stream,omitempty"`

    // Deprecated: set the model name with Model instead
@@ -465,6 +468,13 @@ type ModelDetails struct {
    QuantizationLevel string `json:"quantization_level"`
}

+// Tensor describes the metadata for a given tensor.
+type Tensor struct {
+    Name  string   `json:"name"`
+    Type  string   `json:"type"`
+    Shape []uint64 `json:"shape"`
+}
+
func (m *Metrics) Summary() {
    if m.TotalDuration > 0 {
        fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
@@ -609,7 +619,7 @@ func DefaultOptions() Options {

    Runner: Runner{
        // options set when the model is loaded
-        NumCtx:    2048,
+        NumCtx:    int(envconfig.ContextLength()),
        NumBatch:  512,
        NumGPU:    -1, // -1 here indicates that NumGPU should be set dynamically
        NumThread: 0,  // let the runtime decide
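To show how the new `ShowResponse.Tensors` field fits together with the verbose show path, here is a small, hedged sketch using the Go client. It assumes a running server and a pulled model (`llama3.2` is a placeholder), and that the server honors `ShowRequest.Verbose` the way the CLI change later in this diff does:

```go
package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    // Verbose asks the server to include the heavier fields such as Tensors.
    resp, err := client.Show(context.Background(), &api.ShowRequest{
        Model:   "llama3.2", // placeholder model name
        Verbose: true,
    })
    if err != nil {
        log.Fatal(err)
    }

    // Each Tensor carries the name, type, and shape added to the API above.
    for _, t := range resp.Tensors {
        fmt.Printf("%-30s %-6s %v\n", t.Name, t.Type, t.Shape)
    }
}
```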
@@ -17,6 +17,6 @@ If you want to build the installer, you'll need to install
In the top directory of this repo, run the following powershell script
to build the ollama CLI, ollama app, and ollama installer.

-```
+```powershell
powershell -ExecutionPolicy Bypass -File .\scripts\build_windows.ps1
```
178 benchmark/server_benchmark_test.go (new file)
@@ -0,0 +1,178 @@
package benchmark

import (
    "context"
    "flag"
    "fmt"
    "testing"
    "time"

    "github.com/ollama/ollama/api"
)

// Command line flags
var modelFlag string

func init() {
    flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
    flag.Lookup("m").DefValue = "model"
}

// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
    if modelFlag == "" {
        b.Fatal("Error: -m flag is required for benchmark tests")
    }
    return modelFlag
}

type TestCase struct {
    name      string
    prompt    string
    maxTokens int
}

// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
    start := time.Now()
    var ttft time.Duration
    var metrics api.Metrics

    err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
        if ttft == 0 && resp.Response != "" {
            ttft = time.Since(start)
        }
        if resp.Done {
            metrics = resp.Metrics
        }
        return nil
    })

    // Report custom metrics as part of the benchmark results
    b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
    b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")

    // Token throughput metrics
    promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
    genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
    b.ReportMetric(promptThroughput, "prompt_tok/s")
    b.ReportMetric(genThroughput, "gen_tok/s")

    // Token counts
    b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
    b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
    if err != nil {
        b.Fatal(err)
    }
}

// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
    client := setup(b)
    tests := []TestCase{
        {"short_prompt", "Write a long story", 100},
        {"medium_prompt", "Write a detailed economic analysis", 500},
        {"long_prompt", "Write a comprehensive AI research paper", 1000},
    }
    m := modelName(b)

    for _, tt := range tests {
        b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
            ctx := context.Background()

            // Set number of tokens as our throughput metric
            b.SetBytes(int64(tt.maxTokens))

            for b.Loop() {
                b.StopTimer()
                // Ensure model is unloaded before each iteration
                unload(client, m, b)
                b.StartTimer()

                req := &api.GenerateRequest{
                    Model:   m,
                    Prompt:  tt.prompt,
                    Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
                }

                runGenerateBenchmark(b, ctx, client, req)
            }
        })
    }
}

// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
    client := setup(b)
    tests := []TestCase{
        {"short_prompt", "Write a long story", 100},
        {"medium_prompt", "Write a detailed economic analysis", 500},
        {"long_prompt", "Write a comprehensive AI research paper", 1000},
    }
    m := modelName(b)

    for _, tt := range tests {
        b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
            ctx := context.Background()

            // Pre-warm the model
            warmup(client, m, tt.prompt, b)

            // Set number of tokens as our throughput metric
            b.SetBytes(int64(tt.maxTokens))

            for b.Loop() {
                req := &api.GenerateRequest{
                    Model:   m,
                    Prompt:  tt.prompt,
                    Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
                }

                runGenerateBenchmark(b, ctx, client, req)
            }
        })
    }
}

// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        b.Fatal(err)
    }
    if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
        b.Fatalf("Model unavailable: %v", err)
    }

    return client
}

// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
    for range 3 {
        err := client.Generate(
            context.Background(),
            &api.GenerateRequest{
                Model:   model,
                Prompt:  prompt,
                Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
            },
            func(api.GenerateResponse) error { return nil },
        )
        if err != nil {
            b.Logf("Error during model warm-up: %v", err)
        }
    }
}

// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
    req := &api.GenerateRequest{
        Model:     model,
        KeepAlive: &api.Duration{Duration: 0},
    }
    if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
        b.Logf("Unload error: %v", err)
    }
    time.Sleep(1 * time.Second)
}
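A hedged usage note: because the file registers its own `-m` flag, these benchmarks are meant to be driven through `go test` against an already running server with the target model pulled, along the lines of `go test -bench=. -m <model-name> ./benchmark/`. Only the `-m` requirement is established by the code above; the rest of the invocation is an assumption.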
68 cmd/cmd.go
@@ -18,6 +18,7 @@ import (
    "os/signal"
    "path/filepath"
    "runtime"
+    "sort"
    "strconv"
    "strings"
    "sync/atomic"
@@ -34,10 +35,9 @@ import (
    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/format"
-    "github.com/ollama/ollama/llama"
-    "github.com/ollama/ollama/llama/runner"
    "github.com/ollama/ollama/parser"
    "github.com/ollama/ollama/progress"
+    "github.com/ollama/ollama/runner"
    "github.com/ollama/ollama/server"
    "github.com/ollama/ollama/types/model"
    "github.com/ollama/ollama/version"
@@ -256,6 +256,7 @@ func StopHandler(cmd *cobra.Command, args []string) error {
        if strings.Contains(err.Error(), "not found") {
            return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
        }
+        return err
    }
    return nil
}
@@ -338,7 +339,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
        return err
    }

-    opts.MultiModal = len(info.ProjectorInfo) != 0
+    if len(info.ProjectorInfo) != 0 {
+        opts.MultiModal = true
+    }
+    for k := range info.ModelInfo {
+        if strings.Contains(k, ".vision.") {
+            opts.MultiModal = true
+            break
+        }
+    }
+
    opts.ParentModel = info.Details.ParentModel

    if interactive {
@@ -559,8 +569,9 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
    parameters, errParams := cmd.Flags().GetBool("parameters")
    system, errSystem := cmd.Flags().GetBool("system")
    template, errTemplate := cmd.Flags().GetBool("template")
+    verbose, errVerbose := cmd.Flags().GetBool("verbose")

-    for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
+    for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
        if boolErr != nil {
            return errors.New("error retrieving flags")
        }
@@ -598,7 +609,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
        return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
    }

-    req := api.ShowRequest{Name: args[0]}
+    req := api.ShowRequest{Name: args[0], Verbose: verbose}
    resp, err := client.Show(cmd.Context(), &req)
    if err != nil {
        return err
@@ -621,10 +632,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
        return nil
    }

-    return showInfo(resp, os.Stdout)
+    return showInfo(resp, verbose, os.Stdout)
}

-func showInfo(resp *api.ShowResponse, w io.Writer) error {
+func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
    tableRender := func(header string, rows func() [][]string) {
        fmt.Fprintln(w, " ", header)
        table := tablewriter.NewWriter(w)
@@ -681,6 +692,47 @@ func showInfo(resp *api.ShowResponse, w io.Writer) error {
        })
    }

+    if resp.ModelInfo != nil && verbose {
+        tableRender("Metadata", func() (rows [][]string) {
+            keys := make([]string, 0, len(resp.ModelInfo))
+            for k := range resp.ModelInfo {
+                keys = append(keys, k)
+            }
+            sort.Strings(keys)
+
+            for _, k := range keys {
+                var v string
+                switch vData := resp.ModelInfo[k].(type) {
+                case bool:
+                    v = fmt.Sprintf("%t", vData)
+                case string:
+                    v = vData
+                case float64:
+                    v = fmt.Sprintf("%g", vData)
+                case []any:
+                    n := 3
+                    if len(vData) < n {
+                        n = len(vData)
+                    }
+                    v = fmt.Sprintf("%v", vData[:n])
+                default:
+                    v = fmt.Sprintf("%T", vData)
+                }
+                rows = append(rows, []string{"", k, v})
+            }
+            return
+        })
+    }
+
+    if len(resp.Tensors) > 0 && verbose {
+        tableRender("Tensors", func() (rows [][]string) {
+            for _, t := range resp.Tensors {
+                rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
+            }
+            return
+        })
+    }
+
    head := func(s string, n int) (rows [][]string) {
        scanner := bufio.NewScanner(strings.NewReader(s))
        for scanner.Scan() && (len(rows) < n || n < 0) {
@@ -1187,6 +1239,7 @@ func NewCLI() *cobra.Command {
    showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
    showCmd.Flags().Bool("template", false, "Show template of a model")
    showCmd.Flags().Bool("system", false, "Show system message of a model")
+    showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")

    runCmd := &cobra.Command{
        Use: "run MODEL [PROMPT]",
@@ -1271,7 +1324,6 @@ func NewCLI() *cobra.Command {

    runnerCmd := &cobra.Command{
        Use:    "runner",
-        Short:  llama.PrintSystemInfo(),
        Hidden: true,
        RunE: func(cmd *cobra.Command, args []string) error {
            return runner.Execute(os.Args[1:])
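With the flag registered above, the intended usage is along the lines of `ollama show <model> --verbose` (or `-v`), which makes `showInfo` render the additional Metadata and Tensors tables; `<model>` is a placeholder for any locally available model name.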
286
cmd/cmd_test.go
286
cmd/cmd_test.go
@@ -10,6 +10,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -26,7 +27,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
ParameterSize: "7B",
|
ParameterSize: "7B",
|
||||||
QuantizationLevel: "FP16",
|
QuantizationLevel: "FP16",
|
||||||
},
|
},
|
||||||
}, &b); err != nil {
|
}, false, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,7 +57,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
ParameterSize: "7B",
|
ParameterSize: "7B",
|
||||||
QuantizationLevel: "FP16",
|
QuantizationLevel: "FP16",
|
||||||
},
|
},
|
||||||
}, &b); err != nil {
|
}, false, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,6 +68,60 @@ func TestShowInfo(t *testing.T) {
|
|||||||
embedding length 0
|
embedding length 0
|
||||||
quantization FP16
|
quantization FP16
|
||||||
|
|
||||||
|
`
|
||||||
|
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||||
|
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("verbose model", func(t *testing.T) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := showInfo(&api.ShowResponse{
|
||||||
|
Details: api.ModelDetails{
|
||||||
|
Family: "test",
|
||||||
|
ParameterSize: "8B",
|
||||||
|
QuantizationLevel: "FP16",
|
||||||
|
},
|
||||||
|
Parameters: `
|
||||||
|
stop up`,
|
||||||
|
ModelInfo: map[string]any{
|
||||||
|
"general.architecture": "test",
|
||||||
|
"general.parameter_count": float64(8_000_000_000),
|
||||||
|
"some.true_bool": true,
|
||||||
|
"some.false_bool": false,
|
||||||
|
"test.context_length": float64(1000),
|
||||||
|
"test.embedding_length": float64(11434),
|
||||||
|
},
|
||||||
|
Tensors: []api.Tensor{
|
||||||
|
{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
|
||||||
|
{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
|
||||||
|
},
|
||||||
|
}, true, &b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expect := ` Model
|
||||||
|
architecture test
|
||||||
|
parameters 8B
|
||||||
|
context length 1000
|
||||||
|
embedding length 11434
|
||||||
|
quantization FP16
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
stop up
|
||||||
|
|
||||||
|
Metadata
|
||||||
|
general.architecture test
|
||||||
|
general.parameter_count 8e+09
|
||||||
|
some.false_bool false
|
||||||
|
some.true_bool true
|
||||||
|
test.context_length 1000
|
||||||
|
test.embedding_length 11434
|
||||||
|
|
||||||
|
Tensors
|
||||||
|
blk.0.attn_k.weight BF16 [42 3117]
|
||||||
|
blk.0.attn_q.weight FP16 [3117 42]
|
||||||
|
|
||||||
`
|
`
|
||||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||||
@@ -88,7 +143,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
stop you
|
stop you
|
||||||
stop up
|
stop up
|
||||||
temperature 99`,
|
temperature 99`,
|
||||||
}, &b); err != nil {
|
}, false, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +180,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
"clip.vision.embedding_length": float64(0),
|
"clip.vision.embedding_length": float64(0),
|
||||||
"clip.vision.projection_dim": float64(0),
|
"clip.vision.projection_dim": float64(0),
|
||||||
},
|
},
|
||||||
}, &b); err != nil {
|
}, false, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,7 +213,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
Ahoy, matey!
|
Ahoy, matey!
|
||||||
Weigh anchor!
|
Weigh anchor!
|
||||||
`,
|
`,
|
||||||
}, &b); err != nil {
|
}, false, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,7 +242,7 @@ Weigh anchor!
|
|||||||
QuantizationLevel: "FP16",
|
QuantizationLevel: "FP16",
|
||||||
},
|
},
|
||||||
License: license,
|
License: license,
|
||||||
}, &b); err != nil {
|
}, false, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -490,6 +545,96 @@ func TestPushHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestListHandler(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
serverResponse []api.ListModelResponse
|
||||||
|
expectedError string
|
||||||
|
expectedOutput string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "list all models",
|
||||||
|
args: []string{},
|
||||||
|
serverResponse: []api.ListModelResponse{
|
||||||
|
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
|
||||||
|
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)},
|
||||||
|
},
|
||||||
|
expectedOutput: "NAME ID SIZE MODIFIED \n" +
|
||||||
|
"model1 sha256:abc12 1.0 KB 24 hours ago \n" +
|
||||||
|
"model2 sha256:def45 2.0 KB 2 days ago \n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter models by prefix",
|
||||||
|
args: []string{"model1"},
|
||||||
|
serverResponse: []api.ListModelResponse{
|
||||||
|
{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
|
||||||
|
{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)},
|
||||||
|
},
|
||||||
|
expectedOutput: "NAME ID SIZE MODIFIED \n" +
|
||||||
|
"model1 sha256:abc12 1.0 KB 24 hours ago \n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server error",
|
||||||
|
args: []string{},
|
||||||
|
expectedError: "server error",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/api/tags" || r.Method != http.MethodGet {
|
||||||
|
t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectedError != "" {
|
||||||
|
http.Error(w, tt.expectedError, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := api.ListResponse{Models: tt.serverResponse}
|
||||||
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer mockServer.Close()
|
||||||
|
|
||||||
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.SetContext(context.TODO())
|
||||||
|
|
||||||
|
// Capture stdout
|
||||||
|
oldStdout := os.Stdout
|
||||||
|
r, w, _ := os.Pipe()
|
||||||
|
os.Stdout = w
|
||||||
|
|
||||||
|
err := ListHandler(cmd, tt.args)
|
||||||
|
|
||||||
|
// Restore stdout and get output
|
||||||
|
w.Close()
|
||||||
|
os.Stdout = oldStdout
|
||||||
|
output, _ := io.ReadAll(r)
|
||||||
|
|
||||||
|
if tt.expectedError == "" {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if got := string(output); got != tt.expectedOutput {
|
||||||
|
t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
|
||||||
|
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateHandler(t *testing.T) {
|
func TestCreateHandler(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -616,3 +761,132 @@ func TestCreateHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewCreateRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
from string
|
||||||
|
opts runOptions
|
||||||
|
expected *api.CreateRequest
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"basic test",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "mymodel",
|
||||||
|
ParentModel: "",
|
||||||
|
Prompt: "You are a fun AI agent",
|
||||||
|
Messages: []api.Message{},
|
||||||
|
WordWrap: true,
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "mymodel",
|
||||||
|
Model: "newmodel",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"parent model test",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "mymodel",
|
||||||
|
ParentModel: "parentmodel",
|
||||||
|
Messages: []api.Message{},
|
||||||
|
WordWrap: true,
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "parentmodel",
|
||||||
|
Model: "newmodel",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"parent model as filepath test",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "mymodel",
|
||||||
|
ParentModel: "/some/file/like/etc/passwd",
|
||||||
|
Messages: []api.Message{},
|
||||||
|
WordWrap: true,
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "mymodel",
|
||||||
|
Model: "newmodel",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"parent model as windows filepath test",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "mymodel",
|
||||||
|
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
|
||||||
|
Messages: []api.Message{},
|
||||||
|
WordWrap: true,
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "mymodel",
|
||||||
|
Model: "newmodel",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"options test",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "mymodel",
|
||||||
|
ParentModel: "parentmodel",
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "parentmodel",
|
||||||
|
Model: "newmodel",
|
||||||
|
Parameters: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"messages test",
|
||||||
|
"newmodel",
|
||||||
|
runOptions{
|
||||||
|
Model: "mymodel",
|
||||||
|
ParentModel: "parentmodel",
|
||||||
|
System: "You are a fun AI agent",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "hello there!",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "hello to you!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WordWrap: true,
|
||||||
|
},
|
||||||
|
&api.CreateRequest{
|
||||||
|
From: "parentmodel",
|
||||||
|
Model: "newmodel",
|
||||||
|
System: "You are a fun AI agent",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "hello there!",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "hello to you!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
actual := NewCreateRequest(tt.from, tt.opts)
|
||||||
|
if !cmp.Equal(actual, tt.expected) {
|
||||||
|
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/readline"
    "github.com/ollama/ollama/types/errtypes"
+    "github.com/ollama/ollama/types/model"
)

type MultilineState int
@@ -195,6 +196,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
            opts.Messages = []api.Message{}
            fmt.Printf("Loading model '%s'\n", opts.Model)
            if err := loadOrUnloadModel(cmd, &opts); err != nil {
+                if strings.Contains(err.Error(), "not found") {
+                    fmt.Printf("error: %v\n", err)
+                    continue
+                }
                return err
            }
            continue
@@ -343,7 +348,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {

            switch args[1] {
            case "info":
-                _ = showInfo(resp, os.Stderr)
+                _ = showInfo(resp, false, os.Stderr)
            case "license":
                if resp.License == "" {
                    fmt.Println("No license was specified for this model.")
@@ -455,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}

func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
+    parentModel := opts.ParentModel
+
+    modelName := model.ParseName(parentModel)
+    if !modelName.IsValid() {
+        parentModel = ""
+    }
+
    req := &api.CreateRequest{
-        Name: name,
-        From: cmp.Or(opts.ParentModel, opts.Model),
+        Model: name,
+        From:  cmp.Or(parentModel, opts.Model),
    }

    if opts.System != "" {
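A note on the design choice: the `model.ParseName` guard in `NewCreateRequest` means a `ParentModel` that is not a valid model name, such as a Unix or Windows file path, is discarded so `From` falls back to `opts.Model`; the "parent model as filepath" cases in `cmd/cmd_test.go` earlier in this diff exercise exactly that behavior.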
@@ -4,7 +4,7 @@ import (
    "fmt"
    "os"

-    "github.com/ollama/ollama/llama/runner"
+    "github.com/ollama/ollama/runner"
)

func main() {
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ModelParameters struct {
|
type ModelParameters struct {
|
||||||
Architectures []string `json:"architectures"`
|
Architectures []string `json:"architectures"`
|
||||||
VocabSize uint32 `json:"vocab_size"`
|
VocabSize uint32 `json:"vocab_size"`
|
||||||
|
TextModel TextParameters `json:"text_config"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TextParameters struct {
|
||||||
|
VocabSize uint32 `json:"vocab_size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AdapterParameters struct {
|
type AdapterParameters struct {
|
||||||
@@ -27,8 +32,8 @@ type AdapterParameters struct {
|
|||||||
} `json:"lora_parameters"`
|
} `json:"lora_parameters"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ModelParameters) KV(t *Tokenizer) llm.KV {
|
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||||
kv := llm.KV{
|
kv := ggml.KV{
|
||||||
"general.file_type": uint32(1),
|
"general.file_type": uint32(1),
|
||||||
"general.quantization_version": uint32(2),
|
"general.quantization_version": uint32(2),
|
||||||
"tokenizer.ggml.pre": t.Pre,
|
"tokenizer.ggml.pre": t.Pre,
|
||||||
@@ -54,7 +59,7 @@ func (ModelParameters) KV(t *Tokenizer) llm.KV {
|
|||||||
return kv
|
return kv
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p AdapterParameters) KV() llm.KV {
|
func (p AdapterParameters) KV() ggml.KV {
|
||||||
var alpha float32
|
var alpha float32
|
||||||
if p.LoraParameters.Alpha == 0 {
|
if p.LoraParameters.Alpha == 0 {
|
||||||
alpha = float32(p.Alpha)
|
alpha = float32(p.Alpha)
|
||||||
@@ -62,7 +67,7 @@ func (p AdapterParameters) KV() llm.KV {
|
|||||||
alpha = p.LoraParameters.Alpha
|
alpha = p.LoraParameters.Alpha
|
||||||
}
|
}
|
||||||
|
|
||||||
kv := llm.KV{
|
kv := ggml.KV{
|
||||||
"adapter.lora.alpha": alpha,
|
"adapter.lora.alpha": alpha,
|
||||||
"adapter.type": "lora",
|
"adapter.type": "lora",
|
||||||
"general.file_type": uint32(1),
|
"general.file_type": uint32(1),
|
||||||
@@ -79,19 +84,19 @@ func (ModelParameters) specialTokenTypes() []string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
|
func (ModelParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
|
||||||
return llm.WriteGGUF(ws, kv, ts)
|
return ggml.WriteGGUF(ws, kv, ts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
|
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
|
||||||
return llm.WriteGGUF(ws, kv, ts)
|
return ggml.WriteGGUF(ws, kv, ts)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelConverter interface {
|
type ModelConverter interface {
|
||||||
// KV maps parameters to LLM key-values
|
// KV maps parameters to LLM key-values
|
||||||
KV(*Tokenizer) llm.KV
|
KV(*Tokenizer) ggml.KV
|
||||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||||
Tensors([]Tensor) []llm.Tensor
|
Tensors([]Tensor) []ggml.Tensor
|
||||||
// Replacements returns a list of string pairs to replace in tensor names.
|
// Replacements returns a list of string pairs to replace in tensor names.
|
||||||
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
||||||
Replacements() []string
|
Replacements() []string
|
||||||
@@ -99,7 +104,7 @@ type ModelConverter interface {
|
|||||||
// specialTokenTypes returns any special token types the model uses
|
// specialTokenTypes returns any special token types the model uses
|
||||||
specialTokenTypes() []string
|
specialTokenTypes() []string
|
||||||
// writeFile writes the model to the provided io.WriteSeeker
|
// writeFile writes the model to the provided io.WriteSeeker
|
||||||
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
|
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type moreParser interface {
|
type moreParser interface {
|
||||||
@@ -108,17 +113,17 @@ type moreParser interface {
|
|||||||
|
|
||||||
type AdapterConverter interface {
|
type AdapterConverter interface {
|
||||||
// KV maps parameters to LLM key-values
|
// KV maps parameters to LLM key-values
|
||||||
KV(llm.KV) llm.KV
|
KV(ggml.KV) ggml.KV
|
||||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||||
Tensors([]Tensor) []llm.Tensor
|
Tensors([]Tensor) []ggml.Tensor
|
||||||
// Replacements returns a list of string pairs to replace in tensor names.
|
// Replacements returns a list of string pairs to replace in tensor names.
|
||||||
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
||||||
Replacements() []string
|
Replacements() []string
|
||||||
|
|
||||||
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
|
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
|
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
|
||||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -185,6 +190,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
conv = &gemmaModel{}
|
conv = &gemmaModel{}
|
||||||
case "Gemma2ForCausalLM":
|
case "Gemma2ForCausalLM":
|
||||||
conv = &gemma2Model{}
|
conv = &gemma2Model{}
|
||||||
|
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
|
||||||
|
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||||
case "Phi3ForCausalLM":
|
case "Phi3ForCausalLM":
|
||||||
conv = &phi3Model{}
|
conv = &phi3Model{}
|
||||||
case "Qwen2ForCausalLM":
|
case "Qwen2ForCausalLM":
|
||||||
@@ -194,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
case "CohereForCausalLM":
|
case "CohereForCausalLM":
|
||||||
conv = &commandrModel{}
|
conv = &commandrModel{}
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported architecture")
|
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(bts, conv); err != nil {
|
if err := json.Unmarshal(bts, conv); err != nil {
|
||||||
@@ -213,7 +220,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vocabSize := int(p.VocabSize)
|
vocabSize := int(p.VocabSize)
|
||||||
|
if vocabSize == 0 {
|
||||||
|
tVocabSize := int(p.TextModel.VocabSize)
|
||||||
|
vocabSize = tVocabSize
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
|
case vocabSize == 0:
|
||||||
|
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
|
||||||
case vocabSize > len(t.Vocabulary.Tokens):
|
case vocabSize > len(t.Vocabulary.Tokens):
|
||||||
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
||||||
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
    "slices"
    "strings"

-    "github.com/ollama/ollama/llm"
+    "github.com/ollama/ollama/fs/ggml"
)

type bertModel struct {
@@ -85,7 +85,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
    return nil
}

-func (p *bertModel) KV(t *Tokenizer) llm.KV {
+func (p *bertModel) KV(t *Tokenizer) ggml.KV {
    kv := p.ModelParameters.KV(t)
    kv["general.architecture"] = "bert"
    kv["bert.attention.causal"] = false
@@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) llm.KV {
    return kv
}

-func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
-    var out []llm.Tensor
+func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
+    var out []ggml.Tensor
    for _, t := range ts {
        if slices.Contains([]string{
            "embeddings.position_ids",
@@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
            continue
        }

-        out = append(out, llm.Tensor{
+        out = append(out, ggml.Tensor{
            Name:  t.Name(),
            Kind:  t.Kind(),
            Shape: t.Shape(),
@@ -3,7 +3,7 @@ package convert
|
|||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
)
|
)
|
||||||
|
|
||||||
type commandrModel struct {
|
type commandrModel struct {
|
||||||
@@ -24,7 +24,7 @@ type commandrModel struct {
|
|||||||
|
|
||||||
var _ ModelConverter = (*commandrModel)(nil)
|
var _ ModelConverter = (*commandrModel)(nil)
|
||||||
|
|
||||||
func (p *commandrModel) KV(t *Tokenizer) llm.KV {
|
func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||||
kv := p.ModelParameters.KV(t)
|
kv := p.ModelParameters.KV(t)
|
||||||
kv["general.architecture"] = "command-r"
|
kv["general.architecture"] = "command-r"
|
||||||
kv["general.name"] = "command-r"
|
kv["general.name"] = "command-r"
|
||||||
@@ -43,10 +43,10 @@ func (p *commandrModel) KV(t *Tokenizer) llm.KV {
|
|||||||
return kv
|
return kv
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *commandrModel) Tensors(ts []Tensor) []llm.Tensor {
|
func (p *commandrModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||||
var out []llm.Tensor
|
var out []ggml.Tensor
|
||||||
for _, t := range ts {
|
for _, t := range ts {
|
||||||
out = append(out, llm.Tensor{
|
out = append(out, ggml.Tensor{
|
||||||
Name: t.Name(),
|
Name: t.Name(),
|
||||||
Kind: t.Kind(),
|
Kind: t.Kind(),
|
||||||
Shape: t.Shape(),
|
Shape: t.Shape(),
|
||||||
@@ -6,7 +6,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type gemmaModel struct {
@@ -23,7 +23,7 @@ type gemmaModel struct {

 var _ ModelConverter = (*gemmaModel)(nil)

-func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
+func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "gemma"
 	kv["gemma.context_length"] = p.MaxPositionEmbeddings
@@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *gemmaModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
-		if strings.HasSuffix(t.Name(), "_norm.weight") {
+		if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
 			t.SetRepacker(p.addOne)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
@@ -1,8 +1,6 @@
 package convert

-import (
-	"github.com/ollama/ollama/llm"
-)
+import "github.com/ollama/ollama/fs/ggml"

 type gemma2Model struct {
 	gemmaModel
@@ -11,7 +9,7 @@ type gemma2Model struct {
 	FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
 }

-func (p *gemma2Model) KV(t *Tokenizer) llm.KV {
+func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "gemma2"
 	kv["gemma2.context_length"] = p.MaxPositionEmbeddings
@@ -6,7 +6,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type gemma2Adapter struct {
@@ -15,14 +15,14 @@ type gemma2Adapter struct {

 var _ AdapterConverter = (*gemma2Adapter)(nil)

-func (p *gemma2Adapter) KV(baseKV llm.KV) llm.KV {
+func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
 	kv := p.AdapterParameters.KV()
 	kv["general.architecture"] = "gemma2"
 	return kv
 }

-func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
 		shape := t.Shape()
 		if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
 			t.SetRepacker(p.repack)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_gemma3.go (new file, 142 lines)
@@ -0,0 +1,142 @@
+package convert
+
+import (
+	"cmp"
+
+	"github.com/ollama/ollama/fs/ggml"
+)
+
+type gemma3Model struct {
+	gemmaModel
+	Architecture string
+	TextModel    struct {
+		HeadDim          uint32 `json:"head_dim"`
+		HiddenSize       uint32 `json:"hidden_size"`
+		HiddenLayers     uint32 `json:"num_hidden_layers"`
+		IntermediateSize uint32 `json:"intermediate_size"`
+		SlidingWindow    uint32 `json:"sliding_window"`
+	} `json:"text_config"`
+	VisionModel struct {
+		NumAttentionHeads uint32  `json:"num_attention_heads"` // attention.head_count 16
+		LayerNormEpsilon  float32 `json:"layer_norm_eps"`       // attention.layer_norm_epsilon 1e-05
+		NumHiddenLayers   uint32  `json:"num_hidden_layers"`    // block_count 32
+		HiddenSize        uint32  `json:"hidden_size"`          // embedding_length 1280
+		IntermediateSize  uint32  `json:"intermediate_size"`    // feed_forward_length 5120
+		ImageSize         uint32  `json:"image_size"`           // image_size 560
+		NumChannels       uint32  `json:"num_channels"`         // num_channels 3
+		PatchSize         uint32  `json:"patch_size"`           // patch_size 14
+	} `json:"vision_config"`
+	MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
+	NumAttentionHeads        uint32  `json:"num_attention_heads"`
+	NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
+	RMSNormEPS               float32 `json:"rms_norm_eps"`
+	HeadDim                  uint32  `json:"head_dim"`
+	FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
+	RopeLocalTheta           float32 `json:"rope_local_base_freq"`
+	RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
+	SlidingWindow            uint32  `json:"sliding_window"`
+	MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
+}
+
+const (
+	gemma4BLayerCount  = 34
+	gemma12BLayerCount = 48
+	gemma27BLayerCount = 62
+)
+
+func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
+	kv := p.ModelParameters.KV(t)
+	kv["general.architecture"] = "gemma3"
+
+	numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
+	kv["gemma3.block_count"] = numBlocks
+
+	var (
+		numHeads   uint32
+		numKVHeads uint32
+	)
+
+	switch numBlocks {
+	case gemma4BLayerCount:
+		numHeads = 8
+		numKVHeads = 4
+	case gemma12BLayerCount:
+		numHeads = 16
+		numKVHeads = 8
+	case gemma27BLayerCount:
+		numHeads = 32
+		numKVHeads = 16
+	default:
+		numHeads = p.NumAttentionHeads
+		numKVHeads = p.NumKeyValueHeads
+	}
+
+	kv["gemma3.attention.head_count"] = numHeads
+	kv["gemma3.attention.head_count_kv"] = numKVHeads
+
+	switch p.Architecture {
+	case "Gemma3ForCausalLM":
+		kv["gemma3.context_length"] = p.MaxPositionEmbeddings
+		kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+		kv["gemma3.attention.key_length"] = p.HeadDim
+		kv["gemma3.attention.value_length"] = p.HeadDim
+		kv["gemma3.attention.sliding_window"] = p.SlidingWindow
+		kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
+		kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
+		kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
+		kv["gemma3.embedding_length"] = p.HiddenSize
+		kv["gemma3.feed_forward_length"] = p.IntermediateSize
+	default:
+		kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
+		kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
+		kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
+		kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
+		kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
+		kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
+		kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
+		kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
+		kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
+		kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
+		kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
+		kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
+		kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
+		kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
+	}
+
+	if p.MultiModalTokensPerImage > 0 {
+		kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
+	}
+
+	return kv
+}
+
+func (p *gemma3Model) Replacements() []string {
+	return []string{
+		"lm_head", "output",
+		"model.embed_tokens", "token_embd",
+		"model.norm", "output_norm",
+		"vision_tower.vision_model.embeddings", "v",
+		"vision_tower.vision_model", "v",
+		"vision_model.vision_model.embeddings", "v",
+		"vision_model.vision_model", "v",
+		"language_model.", "",
+		"model.layers", "blk",
+		"encoder.layers", "blk",
+		"input_layernorm", "attn_norm",
+		"self_attn.q_proj", "attn_q",
+		"self_attn.q_norm", "attn_q_norm",
+		"self_attn.k_proj", "attn_k",
+		"self_attn.k_norm", "attn_k_norm",
+		"self_attn.v_proj", "attn_v",
+		"self_attn.o_proj", "attn_output",
+		"self_attn.out_proj", "attn_output",
+		"mlp.gate_proj", "ffn_gate",
+		"mlp.down_proj", "ffn_down",
+		"mlp.up_proj", "ffn_up",
+		"post_attention_layernorm", "post_attention_norm",
+		"pre_feedforward_layernorm", "ffn_norm",
+		"post_feedforward_layernorm", "post_ffw_norm",
+		"input_projection_weight", "input_projection.weight",
+		"multi_modal_projector", "mm",
+	}
+}
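The head counts above are keyed off the layer count, and most other fields fall back through `cmp.Or`, which returns its first non-zero argument. A tiny standalone illustration of that defaulting behaviour (the 256 and 1000000 fall-backs are the ones the file sets; the input values are examples only):

```go
package main

import (
	"cmp"
	"fmt"
)

func main() {
	// Absent JSON fields unmarshal to their zero value, so cmp.Or
	// cleanly substitutes a default when a config omits a field.
	var headDim uint32 // not present in the config
	fmt.Println(cmp.Or(headDim, 256))     // 256
	fmt.Println(cmp.Or(uint32(128), 256)) // 128

	var ropeGlobal float32
	fmt.Println(cmp.Or(ropeGlobal, 1000000.0)) // 1e+06
}
```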
@@ -9,7 +9,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type llamaModel struct {
@@ -46,7 +46,7 @@ type llamaModel struct {

 var _ ModelConverter = (*llamaModel)(nil)

-func (p *llamaModel) KV(t *Tokenizer) llm.KV {
+func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "llama"
 	kv["llama.vocab_size"] = p.VocabSize
@@ -120,11 +120,11 @@ func (p *llamaModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor

 	if p.RopeScaling.factors != nil {
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:     "rope_freqs.weight",
 			Kind:     0,
 			Shape:    []uint64{uint64(len(p.RopeScaling.factors))},
@@ -138,7 +138,7 @@ func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
 			t.SetRepacker(p.repack)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
@@ -7,7 +7,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type llamaAdapter struct {
@@ -18,7 +18,7 @@ type llamaAdapter struct {

 var _ AdapterConverter = (*llamaAdapter)(nil)

-func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
+func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
 	kv := p.AdapterParameters.KV()
 	kv["general.architecture"] = "llama"
 	kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
@@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
 	return kv
 }

-func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
 		shape := t.Shape()
 		if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
 			t.SetRepacker(p.repack)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: shape,
@@ -6,7 +6,7 @@ import (
 	"slices"
 	"strings"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type mixtralModel struct {
@@ -15,7 +15,7 @@ type mixtralModel struct {
 	NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
 }

-func (p *mixtralModel) KV(t *Tokenizer) llm.KV {
+func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.llamaModel.KV(t)

 	if p.NumLocalExperts > 0 {
@@ -29,7 +29,7 @@ func (p *mixtralModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *mixtralModel) Tensors(ts []Tensor) []llm.Tensor {
+func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
 	oldnew := []string{
 		"model.layers", "blk",
 		"w1", "ffn_gate_exps",
@@ -56,10 +56,10 @@ func (p *mixtralModel) Tensors(ts []Tensor) []llm.Tensor {
 		return true
 	})

-	var out []llm.Tensor
+	var out []ggml.Tensor
 	for n, e := range experts {
 		// TODO(mxyng): sanity check experts
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  n,
 			Kind:  e[0].Kind(),
 			Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
@@ -8,7 +8,7 @@ import (
 	"strings"
 	"sync"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type phi3Model struct {
@@ -37,7 +37,7 @@ type phi3Model struct {

 var _ ModelConverter = (*phi3Model)(nil)

-func (p *phi3Model) KV(t *Tokenizer) llm.KV {
+func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "phi3"
 	kv["phi3.context_length"] = p.MaxPositionEmbeddings
@@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *phi3Model) Tensors(ts []Tensor) []llm.Tensor {
+func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
 	var addRopeFactors sync.Once

-	out := make([]llm.Tensor, 0, len(ts)+2)
+	out := make([]ggml.Tensor, 0, len(ts)+2)
 	for _, t := range ts {
 		if strings.HasPrefix(t.Name(), "blk.0.") {
 			addRopeFactors.Do(func() {
-				out = append(out, llm.Tensor{
+				out = append(out, ggml.Tensor{
 					Name:     "rope_factors_long.weight",
 					Kind:     0,
 					Shape:    []uint64{uint64(len(p.RopeScaling.LongFactor))},
 					WriterTo: p.RopeScaling.LongFactor,
-				}, llm.Tensor{
+				}, ggml.Tensor{
 					Name:     "rope_factors_short.weight",
 					Kind:     0,
 					Shape:    []uint64{uint64(len(p.RopeScaling.ShortFactor))},
@@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []llm.Tensor {
 			})
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
@@ -1,6 +1,6 @@
 package convert

-import "github.com/ollama/ollama/llm"
+import "github.com/ollama/ollama/fs/ggml"

 type qwen2Model struct {
 	ModelParameters
@@ -21,7 +21,7 @@ type qwen2Model struct {

 var _ ModelConverter = (*qwen2Model)(nil)

-func (q *qwen2Model) KV(t *Tokenizer) llm.KV {
+func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
 	kv := q.ModelParameters.KV(t)
 	kv["general.architecture"] = "qwen2"
 	kv["qwen2.block_count"] = q.HiddenLayers
@@ -45,10 +45,10 @@ func (q *qwen2Model) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (q *qwen2Model) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (q *qwen2Model) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
@@ -20,7 +20,7 @@ import (

 	"golang.org/x/exp/maps"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type tensorData struct {
@@ -29,7 +29,7 @@ type tensorData struct {
 	Shape []int `json:"shape"`
 }

-func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
+func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
 	t.Helper()

 	f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -48,7 +48,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
 	}
 	t.Cleanup(func() { r.Close() })

-	m, _, err := llm.DecodeGGML(r, math.MaxInt)
+	m, _, err := ggml.Decode(r, math.MaxInt)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -60,7 +60,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
 	return r, m.KV(), m.Tensors()
 }

-func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tensors) map[string]string {
+func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
 	actual := make(map[string]string)
 	for k, v := range kv {
 		if s, ok := v.(json.Marshaler); !ok {
@@ -75,7 +75,7 @@ func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tenso
 		}
 	}

-	for _, tensor := range tensors.Items {
+	for _, tensor := range tensors.Items() {
 		sha256sum := sha256.New()
 		sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
 		if _, err := io.Copy(sha256sum, sr); err != nil {
@@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
 	}
 	defer r.Close()

-	m, _, err := llm.DecodeGGML(r, math.MaxInt)
+	m, _, err := ggml.Decode(r, math.MaxInt)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -6,7 +6,9 @@ import (
 	"errors"
 	"fmt"
 	"io/fs"
+	"log/slog"
 	"os"
+	"reflect"
 	"slices"

 	"google.golang.org/protobuf/proto"
@@ -15,6 +17,8 @@ import (
 )

 func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
+	slog.Debug("using spm vocabulary")
+
 	ast, err := parseAdditionalSpecialTokens(fsys)
 	if err != nil {
 		return nil, err
@@ -43,10 +47,19 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 			v.Types = append(v.Types, int32(t))
 		default:
 			tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
-			if slices.Contains(ast, piece.GetPiece()) {
+
+			// temporary fix to handle gemma3 broken configs
+			if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
 				tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
 			}

+			for _, t := range ast {
+				if t.Content == piece.GetPiece() {
+					tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
+					break
+				}
+			}
+
 			v.Types = append(v.Types, tt)
 		}
 	}
@@ -78,10 +91,16 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 		return cmp.Compare(i.id, j.id)
 	})

-	n := len(v.Tokens)
-	for i, t := range ts {
-		if t.id != i+n {
-			return nil, fmt.Errorf("invalid token id: %d", t.id)
+	for _, t := range ts {
+		if t.id < len(v.Tokens) {
+			if v.Tokens[t.id] == t.content {
+				slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
+				continue
+			}
+			return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
+		}
+		if t.id != len(v.Tokens) {
+			return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
 		}

 		v.Tokens = append(v.Tokens, t.content)
@@ -92,7 +111,15 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 	return &v, nil
 }

-func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
+type specialToken struct {
+	Content    string `json:"content"`
+	Lstrip     bool   `json:"lstrip"`
+	Normalized bool   `json:"normalized"`
+	Rstrip     bool   `json:"rstrip"`
+	SingleWord bool   `json:"single_word"`
+}
+
+func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
 	f, err := fsys.Open("special_tokens_map.json")
 	if errors.Is(err, os.ErrNotExist) {
 		return nil, nil
@@ -102,12 +129,43 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
 	defer f.Close()

 	var m struct {
-		AdditionalSpecialTokens []string `json:"additional_special_tokens"`
+		AdditionalSpecialTokens any `json:"additional_special_tokens"`
 	}

 	if err := json.NewDecoder(f).Decode(&m); err != nil {
 		return nil, err
 	}

-	return m.AdditionalSpecialTokens, nil
+	var ast []specialToken
+
+	switch st := m.AdditionalSpecialTokens.(type) {
+	case []string:
+		for _, s := range st {
+			ast = append(ast, specialToken{Content: s})
+		}
+	case []any:
+		for _, s := range st {
+			// marshal and unmarshal the object to get the special token
+			tMap := s.(map[string]any)
+			data, err := json.Marshal(tMap)
+			if err != nil {
+				return nil, err
+			}
+
+			var token specialToken
+			err = json.Unmarshal(data, &token)
+			if err != nil {
+				return nil, err
+			}
+
+			ast = append(ast, token)
+		}
+
+	default:
+		slog.Warn("special token", "unknown token", reflect.TypeOf(st))
+	}
+
+	slog.Debug("spm tokenizer", "additional tokens", ast)
+
+	return ast, nil
 }
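The `additional_special_tokens` field this change handles can arrive either as a bare list of strings or as a list of objects carrying the attributes in the struct tags above. A self-contained sketch of the two shapes and how they can be folded into one slice (the sample file contents and the local `specialToken` mirror are illustrative, not taken from a real model):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// specialToken mirrors the fields read from special_tokens_map.json;
// this local type is a stand-in for illustration only.
type specialToken struct {
	Content    string `json:"content"`
	Lstrip     bool   `json:"lstrip"`
	Normalized bool   `json:"normalized"`
	Rstrip     bool   `json:"rstrip"`
	SingleWord bool   `json:"single_word"`
}

func main() {
	// Shape 1: a bare list of strings.
	legacy := []byte(`{"additional_special_tokens": ["<start_of_turn>", "<end_of_turn>"]}`)
	// Shape 2: a list of objects with per-token attributes.
	object := []byte(`{"additional_special_tokens": [{"content": "<start_of_turn>", "normalized": false}]}`)

	for _, raw := range [][]byte{legacy, object} {
		var m struct {
			AdditionalSpecialTokens any `json:"additional_special_tokens"`
		}
		if err := json.Unmarshal(raw, &m); err != nil {
			panic(err)
		}

		var ast []specialToken
		// encoding/json decodes any JSON array into []any, so one switch
		// over the element type covers both shapes.
		if items, ok := m.AdditionalSpecialTokens.([]any); ok {
			for _, item := range items {
				switch v := item.(type) {
				case string:
					ast = append(ast, specialToken{Content: v})
				case map[string]any:
					data, err := json.Marshal(v)
					if err != nil {
						panic(err)
					}
					var tok specialToken
					if err := json.Unmarshal(data, &tok); err != nil {
						panic(err)
					}
					ast = append(ast, tok)
				}
			}
		}
		fmt.Printf("%+v\n", ast)
	}
}
```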
@@ -57,7 +57,8 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
 		}
 	}

-	if gpuInfo.computeMajor < 6 || gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
+	// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
+	if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
 		return "v11"
 	}
 	return "v12"
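Restated as a standalone predicate (field names simplified for the sketch), the new gate only looks at the driver version:

```go
package main

import "fmt"

// useCUDAv11 mirrors the check above: anything older than driver 12.1,
// including 12.0 with its known cuda v12 library problems, stays on v11.
func useCUDAv11(driverMajor, driverMinor int) bool {
	return driverMajor < 12 || (driverMajor == 12 && driverMinor == 0)
}

func main() {
	fmt.Println(useCUDAv11(11, 8)) // true  -> v11
	fmt.Println(useCUDAv11(12, 0)) // true  -> v11
	fmt.Println(useCUDAv11(12, 4)) // false -> v12
}
```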
@@ -19,9 +19,8 @@ var LibOllamaPath string = func() string {
 		return ""
 	}

-	exe, err = filepath.EvalSymlinks(exe)
-	if err != nil {
-		return ""
+	if eval, err := filepath.EvalSymlinks(exe); err == nil {
+		exe = eval
 	}

 	var libPath string
docs/api.md (37 lines changed)
@@ -31,7 +31,7 @@ Certain endpoints stream responses as JSON objects. Streaming can be disabled by

 ## Generate a completion

-```shell
+```
 POST /api/generate
 ```

@@ -485,7 +485,7 @@ A single JSON object is returned:

 ## Generate a chat completion

-```shell
+```
 POST /api/chat
 ```

@@ -558,6 +558,10 @@ Final response:
 {
   "model": "llama3.2",
   "created_at": "2023-08-04T19:22:45.499127Z",
+  "message": {
+    "role": "assistant",
+    "content": ""
+  },
   "done": true,
   "total_duration": 4883583458,
   "load_duration": 1334875,
@@ -878,6 +882,7 @@ curl http://localhost:11434/api/chat -d '{
 ```

 ##### Response
+
 ```json
 {
   "model": "llama3.2",
@@ -924,7 +929,7 @@ A single JSON object is returned:

 ## Create a Model

-```shell
+```
 POST /api/create
 ```

@@ -1020,7 +1025,7 @@ curl http://localhost:11434/api/create -d '{

 A stream of JSON objects is returned:

-```
+```json
 {"status":"quantizing F16 model to Q4_K_M"}
 {"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
 {"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
@@ -1051,7 +1056,7 @@ curl http://localhost:11434/api/create -d '{

 A stream of JSON objects is returned:

-```
+```json
 {"status":"parsing GGUF"}
 {"status":"using existing layer sha256:432f310a77f4650a88d0fd59ecdd7cebed8d684bafea53cbff0473542964f0c3"}
 {"status":"writing manifest"}
@@ -1118,7 +1123,7 @@ Return 200 OK if the blob exists, 404 Not Found if it does not.

 ## Push a Blob

-```shell
+```
 POST /api/blobs/:digest
 ```

@@ -1142,7 +1147,7 @@ Return 201 Created if the blob was successfully created, 400 Bad Request if the

 ## List Local Models

-```shell
+```
 GET /api/tags
 ```

@@ -1195,7 +1200,7 @@ A single JSON object will be returned.

 ## Show Model Information

-```shell
+```
 POST /api/show
 ```

@@ -1261,7 +1266,7 @@ curl http://localhost:11434/api/show -d '{

 ## Copy a Model

-```shell
+```
 POST /api/copy
 ```

@@ -1284,7 +1289,7 @@ Returns a 200 OK if successful, or a 404 Not Found if the source model doesn't e

 ## Delete a Model

-```shell
+```
 DELETE /api/delete
 ```

@@ -1310,7 +1315,7 @@ Returns a 200 OK if successful, 404 Not Found if the model to be deleted doesn't

 ## Pull a Model

-```shell
+```
 POST /api/pull
 ```

@@ -1382,7 +1387,7 @@ if `stream` is set to false, then the response is a single JSON object:

 ## Push a Model

-```shell
+```
 POST /api/push
 ```

@@ -1447,7 +1452,7 @@ If `stream` is set to `false`, then the response is a single JSON object:

 ## Generate Embeddings

-```shell
+```
 POST /api/embed
 ```

@@ -1515,7 +1520,7 @@ curl http://localhost:11434/api/embed -d '{
 ```

 ## List Running Models
-```shell
+```
 GET /api/ps
 ```

@@ -1562,7 +1567,7 @@ A single JSON object will be returned.

 > Note: this endpoint has been superseded by `/api/embed`

-```shell
+```
 POST /api/embeddings
 ```

@@ -1602,7 +1607,7 @@ curl http://localhost:11434/api/embeddings -d '{

 ## Version

-```shell
+```
 GET /api/version
 ```
docs/benchmark.md (new file, 59 lines)
@@ -0,0 +1,59 @@
+# Benchmark
+
+Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
+
+## When to use
+
+Run these benchmarks when:
+- Making changes to the model inference engine
+- Modifying model loading/unloading logic
+- Changing prompt processing or token generation code
+- Implementing a new model architecture
+- Testing performance across different hardware setups
+
+## Prerequisites
+- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
+
+## Usage and Examples
+
+>[!NOTE]
+>All commands must be run from the root directory of the Ollama project.
+
+Basic syntax:
+```bash
+go test -bench=. ./benchmark/... -m $MODEL_NAME
+```
+
+Required flags:
+- `-bench=.`: Run all benchmarks
+- `-m`: Model name to benchmark
+
+Optional flags:
+- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
+- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
+
+Common usage patterns:
+
+Single benchmark run with a model specified:
+```bash
+go test -bench=. ./benchmark/... -m llama3.3
+```
+
+## Output metrics
+
+The benchmark reports several key metrics:
+
+- `gen_tok/s`: Generated tokens per second
+- `prompt_tok/s`: Prompt processing tokens per second
+- `ttft_ms`: Time to first token in milliseconds
+- `load_ms`: Model load time in milliseconds
+- `gen_tokens`: Total tokens generated
+- `prompt_tokens`: Total prompt tokens processed
+
+Each benchmark runs two scenarios:
+- Cold start: Model is loaded from disk for each test
+- Warm start: Model is pre-loaded in memory
+
+Three prompt lengths are tested for each scenario:
+- Short prompt (100 tokens)
+- Medium prompt (500 tokens)
+- Long prompt (1000 tokens)
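As a quick aside, the optional flags documented above can be combined in a single invocation; the model name and timings below are only an example:

```bash
# five repetitions of every benchmark against one model, with a longer deadline
go test -bench=. ./benchmark/... -m llama3.3 -count 5 -timeout 30m
```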
@@ -3,11 +3,11 @@
 Install prerequisites:

 - [Go](https://go.dev/doc/install)
-- C/C++ Compiler e.g. Clang on macOS, [TDM-GCC](https://jmeubank.github.io/tdm-gcc/download/) (Windows amd64) or [llvm-mingw](https://github.com/mstorsjo/llvm-mingw) (Windows arm64), GCC/Clang on Linux.
+- C/C++ Compiler e.g. Clang on macOS, [TDM-GCC](https://github.com/jmeubank/tdm-gcc/releases/latest) (Windows amd64) or [llvm-mingw](https://github.com/mstorsjo/llvm-mingw) (Windows arm64), GCC/Clang on Linux.

 Then build and run Ollama from the root directory of the repository:

-```
+```shell
 go run . serve
 ```

@@ -23,14 +23,14 @@ Install prerequisites:

 Then, configure and build the project:

-```
+```shell
 cmake -B build
 cmake --build build
 ```

 Lastly, run Ollama:

-```
+```shell
 go run . serve
 ```

@@ -41,36 +41,35 @@ Install prerequisites:
 - [CMake](https://cmake.org/download/)
 - [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) including the Native Desktop Workload
 - (Optional) AMD GPU support
-  - [ROCm](https://rocm.github.io/install.html)
+  - [ROCm](https://rocm.docs.amd.com/en/latest/)
   - [Ninja](https://github.com/ninja-build/ninja/releases)
 - (Optional) NVIDIA GPU support
   - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)

-> [!IMPORTANT]
-> Ensure prerequisites are in `PATH` before running CMake.
-
-> [!IMPORTANT]
-> ROCm is not compatible with Visual Studio CMake generators. Use `-GNinja` when configuring the project.
-
-> [!IMPORTANT]
-> CUDA is only compatible with Visual Studio CMake generators.
-
 Then, configure and build the project:

-```
+```shell
 cmake -B build
 cmake --build build --config Release
 ```

+> [!IMPORTANT]
+> Building for ROCm requires additional flags:
+> ```
+> cmake -B build -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
+> cmake --build build --config Release
+> ```
+
 Lastly, run Ollama:

-```
+```shell
 go run . serve
 ```

 ## Windows (ARM)

-Windows ARM does not support additional acceleration libraries at this time.
+Windows ARM does not support additional acceleration libraries at this time. Do not use cmake, simply `go run` or `go build`.

 ## Linux

@@ -88,26 +87,26 @@ Install prerequisites:

 Then, configure and build the project:

-```
+```shell
 cmake -B build
 cmake --build build
 ```

 Lastly, run Ollama:

-```
+```shell
 go run . serve
 ```

 ## Docker

-```
+```shell
 docker build .
 ```

 ### ROCm

-```
+```shell
 docker build --build-arg FLAVOR=rocm .
 ```

@@ -115,6 +114,46 @@ docker build --build-arg FLAVOR=rocm .

 To run tests, use `go test`:

-```
+```shell
 go test ./...
 ```

+> NOTE: In rare circumstances, you may need to change a package using the new
+> "synctest" package in go1.24.
+>
+> If you do not have the "synctest" package enabled, you will not see build or
+> test failures resulting from your change(s), if any, locally, but CI will
+> break.
+>
+> If you see failures in CI, you can either keep pushing changes to see if the
+> CI build passes, or you can enable the "synctest" package locally to see the
+> failures before pushing.
+>
+> To enable the "synctest" package for testing, run the following command:
+>
+> ```shell
+> GOEXPERIMENT=synctest go test ./...
+> ```
+>
+> If you wish to enable synctest for all go commands, you can set the
+> `GOEXPERIMENT` environment variable in your shell profile or by using:
+>
+> ```shell
+> go env -w GOEXPERIMENT=synctest
+> ```
+>
+> This will enable the "synctest" package for all go commands without needing
+> to set it for all shell sessions.
+>
+> The synctest package is not required for production builds.
+
+## Library detection
+
+Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
+
+* `./lib/ollama` (Windows)
+* `../lib/ollama` (Linux)
+* `.` (macOS)
+* `build/lib/ollama` (for development)
+
+If the libraries are not found, Ollama will not run with any acceleration libraries.
@@ -2,7 +2,7 @@

 ### CPU only

-```bash
+```shell
 docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
 ```

@@ -11,7 +11,8 @@ Install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-

 #### Install with Apt
 1. Configure the repository
-```bash
+
+```shell
 curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
     | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
 curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
@@ -19,34 +20,37 @@ curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-contai
     | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
 sudo apt-get update
 ```

 2. Install the NVIDIA Container Toolkit packages
-```bash
+
+```shell
 sudo apt-get install -y nvidia-container-toolkit
 ```

 #### Install with Yum or Dnf
 1. Configure the repository

-```bash
+```shell
 curl -s -L https://nvidia.github.io/libnvidia-container/stable/rpm/nvidia-container-toolkit.repo \
     | sudo tee /etc/yum.repos.d/nvidia-container-toolkit.repo
 ```

 2. Install the NVIDIA Container Toolkit packages

-```bash
+```shell
 sudo yum install -y nvidia-container-toolkit
 ```

 #### Configure Docker to use Nvidia driver
-```
+
+```shell
 sudo nvidia-ctk runtime configure --runtime=docker
 sudo systemctl restart docker
 ```

 #### Start the container

-```bash
+```shell
 docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
 ```

@@ -57,7 +61,7 @@ docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ol

 To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following command:

-```
+```shell
 docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
 ```

@@ -65,7 +69,7 @@ docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 114

 Now you can run a model:

-```
+```shell
 docker exec -it ollama ollama run llama3.2
 ```
docs/faq.md (31 lines changed)
@@ -20,11 +20,11 @@ Please refer to the [GPU docs](./gpu.md).

 ## How can I specify the context window size?

-By default, Ollama uses a context window size of 2048 tokens.
+By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.

 To change this when using `ollama run`, use `/set parameter`:

-```
+```shell
 /set parameter num_ctx 4096
 ```

@@ -46,10 +46,15 @@ Use the `ollama ps` command to see what models are currently loaded into memory.
 ```shell
 ollama ps
-NAME          ID              SIZE    PROCESSOR    UNTIL
-llama3:70b    bcfb190ca3a7    42 GB   100% GPU     4 minutes from now
 ```

+> **Output**:
+>
+> ```
+> NAME          ID              SIZE    PROCESSOR    UNTIL
+> llama3:70b    bcfb190ca3a7    42 GB   100% GPU     4 minutes from now
+> ```
+
 The `Processor` column will show which memory the model was loaded in to:
 * `100% GPU` means the model was loaded entirely into the GPU
 * `100% CPU` means the model was loaded entirely in system memory

@@ -66,7 +71,7 @@ If Ollama is run as a macOS application, environment variables should be set usi
 1. For each environment variable, call `launchctl setenv`.

 ```bash
-launchctl setenv OLLAMA_HOST "0.0.0.0"
+launchctl setenv OLLAMA_HOST "0.0.0.0:11434"
 ```

 2. Restart Ollama application.

@@ -81,14 +86,14 @@ If Ollama is run as a systemd service, environment variables should be set using

 ```ini
 [Service]
-Environment="OLLAMA_HOST=0.0.0.0"
+Environment="OLLAMA_HOST=0.0.0.0:11434"
 ```

 3. Save and exit.

 4. Reload `systemd` and restart Ollama:

-```bash
+```shell
 systemctl daemon-reload
 systemctl restart ollama
 ```

@@ -182,6 +187,13 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11

 Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.

+For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
+
+```
+# Allow all Chrome, Firefox, and Safari extensions
+OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
+```
+
 Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.

 ## Where are models stored?

@@ -221,16 +233,19 @@ properties.
 If you are using the API you can preload a model by sending the Ollama server an empty request. This works with both the `/api/generate` and `/api/chat` API endpoints.

 To preload the mistral model using the generate endpoint, use:
+
 ```shell
 curl http://localhost:11434/api/generate -d '{"model": "mistral"}'
 ```

 To use the chat completions endpoint, use:
+
 ```shell
 curl http://localhost:11434/api/chat -d '{"model": "mistral"}'
 ```

 To preload a model using the CLI, use the command:
+
 ```shell
 ollama run llama3.2 ""
 ```

@@ -250,11 +265,13 @@ If you're using the API, use the `keep_alive` parameter with the `/api/generate`
 * '0' which will unload the model immediately after generating a response

 For example, to preload a model and leave it in memory use:
+
 ```shell
 curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": -1}'
 ```

 To unload the model and free up memory use:
+
 ```shell
 curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": 0}'
 ```
@@ -7,7 +7,7 @@ Check your compute compatibility to see if your card is supported:

 | Compute Capability | Family | Cards |
 | ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
-| 9.0 | NVIDIA | `H100` |
+| 9.0 | NVIDIA | `H200` `H100` |
 | 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
 | | NVIDIA Professional | `L4` `L40` `RTX 6000` |
 | 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |
|
|||||||
@@ -20,13 +20,13 @@ Make sure that you use the same base model in the `FROM` command as you used to

 Now run `ollama create` from the directory where the `Modelfile` was created:

-```bash
+```shell
 ollama create my-model
 ```

 Lastly, test the model:

-```bash
+```shell
 ollama run my-model
 ```

@@ -75,7 +75,7 @@ RestartSec=3
 Environment="PATH=$PATH"

 [Install]
-WantedBy=default.target
+WantedBy=multi-user.target
 ```

 Then start the service:
@@ -119,7 +119,7 @@ sudo systemctl status ollama

 To customize the installation of Ollama, you can edit the systemd service file or the environment variables by running:

-```
+```shell
 sudo systemctl edit ollama
 ```

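As a sketch of what such an override might contain, the drop-in opened by `systemctl edit` can set environment variables for the service; `OLLAMA_HOST` below is only an illustrative choice.

```shell
sudo systemctl edit ollama
# In the editor that opens, add for example:
#   [Service]
#   Environment="OLLAMA_HOST=0.0.0.0"
# Then reload units and restart the service so the change takes effect
sudo systemctl daemon-reload
sudo systemctl restart ollama
```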
@@ -186,3 +186,9 @@ sudo rm -r /usr/share/ollama
 sudo userdel ollama
 sudo groupdel ollama
 ```
+
+Remove installed libraries:
+
+```shell
+sudo rm -rf /usr/local/lib/ollama
+```
@@ -28,7 +28,7 @@ A model file is the blueprint to create and share models with Ollama.

 The format of the `Modelfile`:

-```modelfile
+```
 # comment
 INSTRUCTION arguments
 ```
@@ -49,7 +49,7 @@ INSTRUCTION arguments

 An example of a `Modelfile` creating a mario blueprint:

-```modelfile
+```
 FROM llama3.2
 # sets the temperature to 1 [higher is more creative, lower is more coherent]
 PARAMETER temperature 1
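To build and chat with a model from a `Modelfile` like the one above, the flow is roughly the following sketch; the name `mario` is just an example.

```shell
# Create the model from the Modelfile in the current directory, then run it
ollama create mario -f ./Modelfile
ollama run mario "Who are you?"
```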
@@ -69,38 +69,44 @@ To use this:

 To view the Modelfile of a given model, use the `ollama show --modelfile` command.

-```bash
-> ollama show --modelfile llama3.2
-# Modelfile generated by "ollama show"
-# To build a new Modelfile based on this one, replace the FROM line with:
-# FROM llama3.2:latest
-FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
-TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
-
-{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
-
-{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ .Response }}<|eot_id|>"""
-PARAMETER stop "<|start_header_id|>"
-PARAMETER stop "<|end_header_id|>"
-PARAMETER stop "<|eot_id|>"
-PARAMETER stop "<|reserved_special_token"
+```shell
+ollama show --modelfile llama3.2
 ```

+> **Output**:
+>
+> ```
+> # Modelfile generated by "ollama show"
+> # To build a new Modelfile based on this one, replace the FROM line with:
+> # FROM llama3.2:latest
+> FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
+> TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+>
+> {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+>
+> {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+>
+> {{ .Response }}<|eot_id|>"""
+> PARAMETER stop "<|start_header_id|>"
+> PARAMETER stop "<|end_header_id|>"
+> PARAMETER stop "<|eot_id|>"
+> PARAMETER stop "<|reserved_special_token"
+> ```


 ## Instructions

 ### FROM (Required)

 The `FROM` instruction defines the base model to use when creating a model.

-```modelfile
+```
 FROM <model name>:<tag>
 ```

 #### Build from existing model

-```modelfile
+```
 FROM llama3.2
 ```

@@ -111,7 +117,7 @@ Additional models can be found at:

 #### Build from a Safetensors model

-```modelfile
+```
 FROM <model directory>
 ```

@@ -125,7 +131,7 @@ Currently supported model architectures:

 #### Build from a GGUF file

-```modelfile
+```
 FROM ./ollama-model.gguf
 ```

@@ -136,7 +142,7 @@ The GGUF file location should be specified as an absolute path or relative to th

 The `PARAMETER` instruction defines a parameter that can be set when the model is run.

-```modelfile
+```
 PARAMETER <parameter> <parametervalue>
 ```

@@ -183,7 +189,7 @@ TEMPLATE """{{ if .System }}<|im_start|>system

 The `SYSTEM` instruction specifies the system message to be used in the template, if applicable.

-```modelfile
+```
 SYSTEM """<system message>"""
 ```

@@ -193,7 +199,7 @@ The `ADAPTER` instruction specifies a fine tuned LoRA adapter that should apply

 #### Safetensor adapter

-```modelfile
+```
 ADAPTER <path to safetensor adapter>
 ```

@@ -204,7 +210,7 @@ Currently supported Safetensor adapters:

 #### GGUF adapter

-```modelfile
+```
 ADAPTER ./ollama-lora.gguf
 ```

@@ -212,7 +218,7 @@ ADAPTER ./ollama-lora.gguf

 The `LICENSE` instruction allows you to specify the legal license under which the model used with this Modelfile is shared or distributed.

-```modelfile
+```
 LICENSE """
 <license text>
 """
@@ -222,7 +228,7 @@ LICENSE """

 The `MESSAGE` instruction allows you to specify a message history for the model to use when responding. Use multiple iterations of the MESSAGE command to build up a conversation which will guide the model to answer in a similar way.

-```modelfile
+```
 MESSAGE <role> <message>
 ```

@@ -237,7 +243,7 @@ MESSAGE <role> <message>

 #### Example conversation

-```modelfile
+```
 MESSAGE user Is Toronto in Canada?
 MESSAGE assistant yes
 MESSAGE user Is Sacramento in Canada?
@@ -1,6 +1,7 @@
 # OpenAI compatibility

-> **Note:** OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/ollama/ollama/blob/main/docs/api.md).
+> [!NOTE]
+> OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/ollama/ollama/blob/main/docs/api.md).

 Ollama provides experimental compatibility with parts of the [OpenAI API](https://platform.openai.com/docs/api-reference) to help connect existing applications to Ollama.

@@ -59,8 +60,10 @@ embeddings = client.embeddings.create(
 input=["why is the sky blue?", "why is the grass green?"],
 )
 ```

 #### Structured outputs
-```py
+
+```python
 from pydantic import BaseModel
 from openai import OpenAI

@@ -319,7 +322,7 @@ ollama pull llama3.2

 For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:

-```
+```shell
 ollama cp llama3.2 gpt-3.5-turbo
 ```

@@ -343,7 +346,7 @@ curl http://localhost:11434/v1/chat/completions \

 The OpenAI API does not have a way of setting the context size for a model. If you need to change the context size, create a `Modelfile` which looks like:

-```modelfile
+```
 FROM <some model>
 PARAMETER num_ctx <context size>
 ```
|||||||
@@ -17,6 +17,7 @@ When you run Ollama in a **container**, the logs go to stdout/stderr in the cont
|
|||||||
```shell
|
```shell
|
||||||
docker logs <container-name>
|
docker logs <container-name>
|
||||||
```
|
```
|
||||||
|
|
||||||
(Use `docker ps` to find the container name)
|
(Use `docker ps` to find the container name)
|
||||||
|
|
||||||
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
||||||
@@ -28,6 +29,7 @@ When you run Ollama on **Windows**, there are a few different locations. You can
 - `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories

 To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal

 ```powershell
 $env:OLLAMA_DEBUG="1"
 & "ollama app.exe"
@@ -49,12 +51,13 @@ Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]

 You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:

-```
+```shell
 OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
 ```

 You can see what features your CPU has with the following.
-```
+
+```shell
 cat /proc/cpuinfo| grep flags | head -1
 ```

@@ -62,14 +65,18 @@ cat /proc/cpuinfo| grep flags | head -1

 If you run into problems on Linux and want to install an older version, or you'd like to try out a pre-release before it's officially released, you can tell the install script which version to install.

-```sh
-curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
+```shell
+curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
 ```

 ## Linux tmp noexec

 If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/

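A quick sketch of setting that variable when launching the server manually; for the systemd service, the same variable would go into the unit's `Environment=` settings instead. The path is the example from above and must be writable by the user Ollama runs as.

```shell
# Point temporary executables at a directory that is not mounted noexec
OLLAMA_TMPDIR=/usr/share/ollama/ ollama serve
```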
+## Linux docker
+
+If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
+
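A sketch of that edit is below; it assumes `/etc/docker/daemon.json` does not already contain other settings (merge by hand if it does), and Docker has to be restarted afterwards.

```shell
# Switch Docker to the cgroupfs cgroup driver, then restart the daemon
sudo tee /etc/docker/daemon.json > /dev/null <<'EOF'
{
  "exec-opts": ["native.cgroupdriver=cgroupfs"]
}
EOF
sudo systemctl restart docker
```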
 ## NVIDIA GPU Discovery

 When Ollama starts up, it takes inventory of the GPUs present in the system to determine compatibility and how much VRAM is available. Sometimes this discovery can fail to find your GPUs. In general, running the latest driver will yield the best results.
@@ -97,8 +104,6 @@ On linux, AMD GPU access typically requires `video` and/or `render` group member

 When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`

-If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
-
 If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
 - `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems
 - `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported
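Putting the `--group-add` guidance above into a container invocation might look like the following sketch; the group IDs `44` and `992` and the ROCm image tag are illustrative, so substitute the numbers reported on your host.

```shell
# Find the numeric group IDs that own the GPU device nodes
ls -lnd /dev/kfd /dev/dri /dev/dri/*
# Pass them to the container so the ollama process can open the devices
docker run -d --device /dev/kfd --device /dev/dri \
  --group-add 44 --group-add 992 \
  -v ollama:/root/.ollama -p 11434:11434 ollama/ollama:rocm
```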
@@ -47,6 +47,7 @@ If Ollama is already running, Quit the tray application and relaunch it from the
 ## API Access

 Here's a quick example showing API access from `powershell`

 ```powershell
 (Invoke-WebRequest -method POST -Body '{"model":"llama3.2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
 ```
@@ -54,7 +55,7 @@ Here's a quick example showing API access from `powershell`
 ## Troubleshooting

 Ollama on Windows stores files in a few different locations. You can view them in
-the explorer window by hitting `<cmd>+R` and type in:
+the explorer window by hitting `<Ctrl>+R` and type in:
 - `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
 - *app.log* contains the most recent logs from the GUI application
 - *server.log* contains the most recent server logs
@@ -80,9 +81,11 @@ help you keep up to date.

 If you'd like to install or integrate Ollama as a service, a standalone
 `ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
-and GPU library dependencies for Nvidia and AMD. This allows for embedding
-Ollama in existing applications, or running it as a system service via `ollama
-serve` with tools such as [NSSM](https://nssm.cc/).
+and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
+and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
+same directory. This allows for embedding Ollama in existing applications, or
+running it as a system service via `ollama serve` with tools such as
+[NSSM](https://nssm.cc/).

 > [!NOTE]
 > If you are upgrading from a prior version, you should remove the old directories first.
@@ -53,8 +53,8 @@ func Host() *url.URL {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
|
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
|
||||||
func Origins() (origins []string) {
|
func AllowedOrigins() (origins []string) {
|
||||||
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
||||||
origins = strings.Split(s, ",")
|
origins = strings.Split(s, ",")
|
||||||
}
|
}
|
||||||
@@ -73,6 +73,7 @@ func Origins() (origins []string) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
)
|
)
|
||||||
|
|
||||||
return origins
|
return origins
|
||||||
@@ -165,6 +166,10 @@ var (
|
|||||||
IntelGPU = Bool("OLLAMA_INTEL_GPU")
|
IntelGPU = Bool("OLLAMA_INTEL_GPU")
|
||||||
// MultiUserCache optimizes prompt caching for multi-user scenarios
|
// MultiUserCache optimizes prompt caching for multi-user scenarios
|
||||||
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
|
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
|
||||||
|
// Enable the new Ollama engine
|
||||||
|
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||||
|
// ContextLength sets the default context length
|
||||||
|
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 2048)
|
||||||
)
|
)
|
||||||
|
|
||||||
func String(s string) func() string {
|
func String(s string) func() string {
|
||||||
@@ -247,9 +252,11 @@ func AsMap() map[string]EnvVar {
|
|||||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
|
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||||
|
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 2048)"},
|
||||||
|
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||||
|
|
||||||
// Informational
|
// Informational
|
||||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||||
|
|||||||
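As a usage note for the `OLLAMA_CONTEXT_LENGTH` variable introduced above, the default context window can be raised when starting the server; 4096 is only an example value.

```shell
# Use a 4096-token default context length unless a request overrides it
OLLAMA_CONTEXT_LENGTH=4096 ollama serve
```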
@@ -69,6 +69,7 @@ func TestOrigins(t *testing.T) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
}},
|
}},
|
||||||
{"http://10.0.0.1", []string{
|
{"http://10.0.0.1", []string{
|
||||||
"http://10.0.0.1",
|
"http://10.0.0.1",
|
||||||
@@ -88,6 +89,7 @@ func TestOrigins(t *testing.T) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
}},
|
}},
|
||||||
{"http://172.16.0.1,https://192.168.0.1", []string{
|
{"http://172.16.0.1,https://192.168.0.1", []string{
|
||||||
"http://172.16.0.1",
|
"http://172.16.0.1",
|
||||||
@@ -108,6 +110,7 @@ func TestOrigins(t *testing.T) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
}},
|
}},
|
||||||
{"http://totally.safe,http://definitely.legit", []string{
|
{"http://totally.safe,http://definitely.legit", []string{
|
||||||
"http://totally.safe",
|
"http://totally.safe",
|
||||||
@@ -128,13 +131,14 @@ func TestOrigins(t *testing.T) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.value, func(t *testing.T) {
|
t.Run(tt.value, func(t *testing.T) {
|
||||||
t.Setenv("OLLAMA_ORIGINS", tt.value)
|
t.Setenv("OLLAMA_ORIGINS", tt.value)
|
||||||
|
|
||||||
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
|
if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" {
|
||||||
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
|
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -272,3 +276,19 @@ func TestVar(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextLength(t *testing.T) {
|
||||||
|
cases := map[string]uint{
|
||||||
|
"": 2048,
|
||||||
|
"4096": 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range cases {
|
||||||
|
t.Run(k, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_CONTEXT_LENGTH", k)
|
||||||
|
if i := ContextLength(); i != v {
|
||||||
|
t.Errorf("%s: expected %d, got %d", k, v, i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -40,8 +40,6 @@ func HumanBytes(b int64) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case value >= 100:
|
|
||||||
return fmt.Sprintf("%d %s", int(value), unit)
|
|
||||||
case value >= 10:
|
case value >= 10:
|
||||||
return fmt.Sprintf("%d %s", int(value), unit)
|
return fmt.Sprintf("%d %s", int(value), unit)
|
||||||
case value != math.Trunc(value):
|
case value != math.Trunc(value):
|
||||||
|
|||||||
format/bytes_test.go (new file, 91 lines)
@@ -0,0 +1,91 @@
|
|||||||
|
package format
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHumanBytes(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
input int64
|
||||||
|
expected string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
// Test bytes (B)
|
||||||
|
{0, "0 B"},
|
||||||
|
{1, "1 B"},
|
||||||
|
{999, "999 B"},
|
||||||
|
|
||||||
|
// Test kilobytes (KB)
|
||||||
|
{1000, "1 KB"},
|
||||||
|
{1500, "1.5 KB"},
|
||||||
|
{999999, "999 KB"},
|
||||||
|
|
||||||
|
// Test megabytes (MB)
|
||||||
|
{1000000, "1 MB"},
|
||||||
|
{1500000, "1.5 MB"},
|
||||||
|
{999999999, "999 MB"},
|
||||||
|
|
||||||
|
// Test gigabytes (GB)
|
||||||
|
{1000000000, "1 GB"},
|
||||||
|
{1500000000, "1.5 GB"},
|
||||||
|
{999999999999, "999 GB"},
|
||||||
|
|
||||||
|
// Test terabytes (TB)
|
||||||
|
{1000000000000, "1 TB"},
|
||||||
|
{1500000000000, "1.5 TB"},
|
||||||
|
{1999999999999, "2.0 TB"},
|
||||||
|
|
||||||
|
// Test fractional values
|
||||||
|
{1234, "1.2 KB"},
|
||||||
|
{1234567, "1.2 MB"},
|
||||||
|
{1234567890, "1.2 GB"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.expected, func(t *testing.T) {
|
||||||
|
result := HumanBytes(tc.input)
|
||||||
|
if result != tc.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tc.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHumanBytes2(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
input uint64
|
||||||
|
expected string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
// Test bytes (B)
|
||||||
|
{0, "0 B"},
|
||||||
|
{1, "1 B"},
|
||||||
|
{1023, "1023 B"},
|
||||||
|
|
||||||
|
// Test kibibytes (KiB)
|
||||||
|
{1024, "1.0 KiB"},
|
||||||
|
{1536, "1.5 KiB"},
|
||||||
|
{1048575, "1024.0 KiB"},
|
||||||
|
|
||||||
|
// Test mebibytes (MiB)
|
||||||
|
{1048576, "1.0 MiB"},
|
||||||
|
{1572864, "1.5 MiB"},
|
||||||
|
{1073741823, "1024.0 MiB"},
|
||||||
|
|
||||||
|
// Test gibibytes (GiB)
|
||||||
|
{1073741824, "1.0 GiB"},
|
||||||
|
{1610612736, "1.5 GiB"},
|
||||||
|
{2147483648, "2.0 GiB"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.expected, func(t *testing.T) {
|
||||||
|
result := HumanBytes2(tc.input)
|
||||||
|
if result != tc.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tc.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,9 @@ func TestHumanNumber(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{0, "0"},
|
{0, "0"},
|
||||||
|
{999, "999"},
|
||||||
|
{1000, "1K"},
|
||||||
|
{1001, "1K"},
|
||||||
{1000000, "1M"},
|
{1000000, "1M"},
|
||||||
{125000000, "125M"},
|
{125000000, "125M"},
|
||||||
{500500000, "500.50M"},
|
{500500000, "500.50M"},
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
package llm
|
package ggml
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/util/bufioutil"
|
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GGML struct {
|
type GGML struct {
|
||||||
@@ -19,121 +19,160 @@ type GGML struct {
|
|||||||
|
|
||||||
type model interface {
|
type model interface {
|
||||||
KV() KV
|
KV() KV
|
||||||
Tensors() *Tensors
|
Tensors() Tensors
|
||||||
}
|
}
|
||||||
|
|
||||||
type KV map[string]any
|
type KV map[string]any
|
||||||
|
|
||||||
func (kv KV) u64(key string) uint64 {
|
|
||||||
switch v := kv[key].(type) {
|
|
||||||
case uint64:
|
|
||||||
return v
|
|
||||||
case uint32:
|
|
||||||
return uint64(v)
|
|
||||||
case float64:
|
|
||||||
return uint64(v)
|
|
||||||
default:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kv KV) Architecture() string {
|
func (kv KV) Architecture() string {
|
||||||
if s, ok := kv["general.architecture"].(string); ok {
|
return kv.String("general.architecture", "unknown")
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
return "unknown"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) Kind() string {
|
func (kv KV) Kind() string {
|
||||||
if s, ok := kv["general.type"].(string); ok {
|
return kv.String("general.type", "unknown")
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
return "unknown"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) ParameterCount() uint64 {
|
func (kv KV) ParameterCount() uint64 {
|
||||||
return kv.u64("general.parameter_count")
|
return keyValue[uint64](kv, "general.parameter_count")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) FileType() fileType {
|
func (kv KV) FileType() fileType {
|
||||||
if u64 := kv.u64("general.file_type"); u64 > 0 {
|
if t := kv.Uint("general.file_type"); t > 0 {
|
||||||
return fileType(uint32(u64))
|
return fileType(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fileTypeUnknown
|
return fileTypeUnknown
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) BlockCount() uint64 {
|
func (kv KV) BlockCount() uint64 {
|
||||||
return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
|
return uint64(kv.Uint("block_count"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) EmbeddingLength() uint64 {
|
||||||
|
return uint64(kv.Uint("embedding_length"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) HeadCount() uint64 {
|
func (kv KV) HeadCount() uint64 {
|
||||||
return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
|
return uint64(kv.Uint("attention.head_count"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) HeadCountKV() uint64 {
|
func (kv KV) HeadCountKV() uint64 {
|
||||||
if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
|
return uint64(kv.Uint("attention.head_count_kv", 1))
|
||||||
return headCountKV
|
|
||||||
}
|
|
||||||
|
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingHeadCount() uint64 {
|
func (kv KV) EmbeddingHeadCount() uint64 {
|
||||||
if heads := kv.HeadCount(); heads > 0 {
|
if heads := kv.HeadCount(); heads > 0 {
|
||||||
return kv.EmbeddingLength() / kv.HeadCount()
|
return kv.EmbeddingLength() / heads
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingHeadCountK() uint64 {
|
func (kv KV) EmbeddingHeadCountK() uint64 {
|
||||||
if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
|
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
|
||||||
return k
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv.EmbeddingHeadCount()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingHeadCountV() uint64 {
|
func (kv KV) EmbeddingHeadCountV() uint64 {
|
||||||
if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
|
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv.EmbeddingHeadCount()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) GQA() uint64 {
|
func (kv KV) GQA() uint64 {
|
||||||
return kv.HeadCount() / kv.HeadCountKV()
|
return kv.HeadCount() / kv.HeadCountKV()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingLength() uint64 {
|
|
||||||
return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kv KV) ContextLength() uint64 {
|
func (kv KV) ContextLength() uint64 {
|
||||||
return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
|
return uint64(kv.Uint("context_length"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) ChatTemplate() string {
|
func (kv KV) ChatTemplate() string {
|
||||||
s, _ := kv["tokenizer.chat_template"].(string)
|
return kv.String("tokenizer.chat_template")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) String(key string, defaultValue ...string) string {
|
||||||
|
return keyValue(kv, key, append(defaultValue, "")...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||||
|
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||||
|
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||||
|
return keyValue(kv, key, append(defaultValue, false)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||||
|
r := keyValue(kv, key, &array{})
|
||||||
|
s := make([]string, r.size)
|
||||||
|
for i := range r.size {
|
||||||
|
s[i] = r.values[i].(string)
|
||||||
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tensors struct {
|
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||||
Items []*Tensor
|
r := keyValue(kv, key, &array{})
|
||||||
Offset uint64
|
s := make([]uint32, r.size)
|
||||||
|
for i := range r.size {
|
||||||
layers map[string]Layer
|
s[i] = uint32(r.values[i].(int32))
|
||||||
layersOnce sync.Once
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *Tensors) Layers() map[string]Layer {
|
return s
|
||||||
ts.layersOnce.Do(func() {
|
}
|
||||||
ts.layers = make(map[string]Layer)
|
|
||||||
for _, t := range ts.Items {
|
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||||
|
r := keyValue(kv, key, &array{})
|
||||||
|
s := make([]float32, r.size)
|
||||||
|
for i := range r.size {
|
||||||
|
s[i] = float32(r.values[i].(float32))
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) OllamaEngineRequired() bool {
|
||||||
|
return kv.Architecture() == "gemma3"
|
||||||
|
}
|
||||||
|
|
||||||
|
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
|
||||||
|
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||||
|
key = kv.Architecture() + "." + key
|
||||||
|
}
|
||||||
|
|
||||||
|
if val, ok := kv[key]; ok {
|
||||||
|
return val.(T)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Warn("key not found", "key", key, "default", defaultValue[0])
|
||||||
|
return defaultValue[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tensors struct {
|
||||||
|
items []*Tensor
|
||||||
|
Offset uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Tensors) Items(prefix ...string) []*Tensor {
|
||||||
|
if len(prefix) == 0 {
|
||||||
|
return s.items
|
||||||
|
}
|
||||||
|
|
||||||
|
var items []*Tensor
|
||||||
|
for _, t := range s.items {
|
||||||
|
if strings.HasPrefix(t.Name, prefix[0]) {
|
||||||
|
items = append(items, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensors) GroupLayers() map[string]Layer {
|
||||||
|
layers := make(map[string]Layer)
|
||||||
|
for _, t := range ts.items {
|
||||||
parts := strings.Split(t.Name, ".")
|
parts := strings.Split(t.Name, ".")
|
||||||
if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
|
if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
|
||||||
if len(parts) > index+2 {
|
if len(parts) > index+2 {
|
||||||
@@ -144,20 +183,19 @@ func (ts *Tensors) Layers() map[string]Layer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := ts.layers[parts[0]]; !ok {
|
if _, ok := layers[parts[0]]; !ok {
|
||||||
ts.layers[parts[0]] = make(Layer)
|
layers[parts[0]] = make(Layer)
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.layers[parts[0]][strings.Join(parts[1:], ".")] = t
|
layers[parts[0]][strings.Join(parts[1:], ".")] = t
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
|
||||||
return ts.layers
|
return layers
|
||||||
}
|
}
|
||||||
|
|
||||||
type Layer map[string]*Tensor
|
type Layer map[string]*Tensor
|
||||||
|
|
||||||
func (l Layer) size() (size uint64) {
|
func (l Layer) Size() (size uint64) {
|
||||||
for _, t := range l {
|
for _, t := range l {
|
||||||
size += t.Size()
|
size += t.Size()
|
||||||
}
|
}
|
||||||
@@ -186,11 +224,26 @@ func (t Tensor) block() (n int) {
|
|||||||
|
|
||||||
func (t Tensor) blockSize() uint64 {
|
func (t Tensor) blockSize() uint64 {
|
||||||
switch t.Kind {
|
switch t.Kind {
|
||||||
case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
|
case
|
||||||
|
0, // F32
|
||||||
|
1, // F16
|
||||||
|
24, // I8
|
||||||
|
25, // I16
|
||||||
|
26, // I32
|
||||||
|
27, // I64
|
||||||
|
28, // F64
|
||||||
|
30: // BF16
|
||||||
return 1
|
return 1
|
||||||
case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
|
case
|
||||||
|
2, // Q4_0
|
||||||
|
3, // Q4_1
|
||||||
|
6, // Q5_0
|
||||||
|
7, // Q5_1
|
||||||
|
8, // Q8_0
|
||||||
|
9, // Q8_1
|
||||||
|
20: // IQ4_NL
|
||||||
return 32
|
return 32
|
||||||
default: // All others
|
default:
|
||||||
return 256
|
return 256
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -214,7 +267,7 @@ func (t Tensor) typeSize() uint64 {
|
|||||||
case 8: // Q8_0
|
case 8: // Q8_0
|
||||||
return 2 + blockSize
|
return 2 + blockSize
|
||||||
case 9: // Q8_1
|
case 9: // Q8_1
|
||||||
return 4 + 4 + blockSize
|
return 2 + 2 + blockSize
|
||||||
case 10: // Q2_K
|
case 10: // Q2_K
|
||||||
return blockSize/16 + blockSize/4 + 2 + 2
|
return blockSize/16 + blockSize/4 + 2 + 2
|
||||||
case 11: // Q3_K
|
case 11: // Q3_K
|
||||||
@@ -226,7 +279,7 @@ func (t Tensor) typeSize() uint64 {
|
|||||||
case 14: // Q6_K
|
case 14: // Q6_K
|
||||||
return blockSize/2 + blockSize/4 + blockSize/16 + 2
|
return blockSize/2 + blockSize/4 + blockSize/16 + 2
|
||||||
case 15: // Q8_K
|
case 15: // Q8_K
|
||||||
return 2 + blockSize + 2*blockSize/16
|
return 4 + blockSize + 2*blockSize/16
|
||||||
case 16: // IQ2_XXS
|
case 16: // IQ2_XXS
|
||||||
return 2 + 2*blockSize/8
|
return 2 + 2*blockSize/8
|
||||||
case 17: // IQ2_XS
|
case 17: // IQ2_XS
|
||||||
@@ -274,6 +327,10 @@ func (t Tensor) Size() uint64 {
|
|||||||
return t.parameters() * t.typeSize() / t.blockSize()
|
return t.parameters() * t.typeSize() / t.blockSize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t Tensor) Type() string {
|
||||||
|
return fileType(t.Kind).String()
|
||||||
|
}
|
||||||
|
|
||||||
type container interface {
|
type container interface {
|
||||||
Name() string
|
Name() string
|
||||||
Decode(io.ReadSeeker) (model, error)
|
Decode(io.ReadSeeker) (model, error)
|
||||||
@@ -295,7 +352,7 @@ const (
|
|||||||
|
|
||||||
var ErrUnsupportedFormat = errors.New("unsupported model format")
|
var ErrUnsupportedFormat = errors.New("unsupported model format")
|
||||||
|
|
||||||
func DetectGGMLType(b []byte) string {
|
func DetectContentType(b []byte) string {
|
||||||
switch binary.LittleEndian.Uint32(b[:4]) {
|
switch binary.LittleEndian.Uint32(b[:4]) {
|
||||||
case FILE_MAGIC_GGML:
|
case FILE_MAGIC_GGML:
|
||||||
return "ggml"
|
return "ggml"
|
||||||
@@ -312,12 +369,12 @@ func DetectGGMLType(b []byte) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeGGML decodes a GGML model from the given reader.
|
// Decode decodes a GGML model from the given reader.
|
||||||
//
|
//
|
||||||
// It collects array values for arrays with a size less than or equal to
|
// It collects array values for arrays with a size less than or equal to
|
||||||
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
|
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
|
||||||
// the maxArraySize is negative, all arrays are collected.
|
// the maxArraySize is negative, all arrays are collected.
|
||||||
func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||||
if maxArraySize == 0 {
|
if maxArraySize == 0 {
|
||||||
maxArraySize = 1024
|
maxArraySize = 1024
|
||||||
}
|
}
|
||||||
@@ -331,10 +388,6 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
|||||||
|
|
||||||
var c container
|
var c container
|
||||||
switch magic {
|
switch magic {
|
||||||
case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
|
|
||||||
return nil, 0, ErrUnsupportedFormat
|
|
||||||
case FILE_MAGIC_GGLA:
|
|
||||||
c = &containerGGLA{}
|
|
||||||
case FILE_MAGIC_GGUF_LE:
|
case FILE_MAGIC_GGUF_LE:
|
||||||
c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
|
c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
|
||||||
case FILE_MAGIC_GGUF_BE:
|
case FILE_MAGIC_GGUF_BE:
|
||||||
@@ -360,22 +413,22 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
|||||||
}, offset, nil
|
}, offset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||||
embedding := llm.KV().EmbeddingLength()
|
embedding := f.KV().EmbeddingLength()
|
||||||
heads := llm.KV().HeadCount()
|
heads := f.KV().HeadCount()
|
||||||
headsKV := llm.KV().HeadCountKV()
|
headsKV := f.KV().HeadCountKV()
|
||||||
vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
|
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
|
||||||
|
|
||||||
embeddingHeads := llm.KV().EmbeddingHeadCount()
|
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||||
embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
|
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||||
embeddingHeadsV := llm.KV().EmbeddingHeadCountV()
|
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
||||||
|
|
||||||
layers := llm.Tensors().Layers()
|
layers := f.Tensors().GroupLayers()
|
||||||
|
|
||||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||||
kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||||
|
|
||||||
switch llm.KV().Architecture() {
|
switch f.KV().Architecture() {
|
||||||
case "llama":
|
case "llama":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(1+4*embedding+context*(1+heads)),
|
4*batch*(1+4*embedding+context*(1+heads)),
|
||||||
@@ -390,7 +443,7 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
|
|||||||
|
|
||||||
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
|
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
|
||||||
// mixtral 8x22b
|
// mixtral 8x22b
|
||||||
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
|
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
|
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
|
||||||
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
|
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
|
||||||
@@ -407,11 +460,11 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
|
|||||||
case "mllama":
|
case "mllama":
|
||||||
var visionTokens, tiles uint64 = 1601, 4
|
var visionTokens, tiles uint64 = 1601, 4
|
||||||
|
|
||||||
if crossAttentionLayers, ok := llm.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
||||||
kv = headsKV *
|
kv = headsKV *
|
||||||
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
||||||
(2* // sizeof(float16)
|
(2* // sizeof(float16)
|
||||||
(llm.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
||||||
context +
|
context +
|
||||||
4* // sizeof(float32)
|
4* // sizeof(float32)
|
||||||
uint64(crossAttentionLayers.size)* // num cross attention layers
|
uint64(crossAttentionLayers.size)* // num cross attention layers
|
||||||
@@ -426,7 +479,7 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
|
|||||||
)
|
)
|
||||||
|
|
||||||
var ropeFreqsCount uint64
|
var ropeFreqsCount uint64
|
||||||
if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok {
|
if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
|
||||||
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
|
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
|
||||||
ropeFreqsCount = ropeFreqsWeights.parameters()
|
ropeFreqsCount = ropeFreqsWeights.parameters()
|
||||||
}
|
}
|
||||||
@@ -440,7 +493,7 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
|
|||||||
// vocab graph
|
// vocab graph
|
||||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||||
)
|
)
|
||||||
case "gemma", "gemma2":
|
case "gemma", "gemma2", "gemma3":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
||||||
@@ -529,22 +582,71 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||||
|
if llm.KV().Uint("vision.block_count") == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, layer := range llm.Tensors().GroupLayers() {
|
||||||
|
if name == "v" || strings.HasPrefix(name, "v.") {
|
||||||
|
for _, tensor := range layer {
|
||||||
|
weights += tensor.Size()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
imageSize := uint64(llm.KV().Uint("vision.image_size"))
|
||||||
|
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
|
||||||
|
if patchSize == 0 {
|
||||||
|
slog.Warn("unknown patch size for vision model")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
|
||||||
|
|
||||||
|
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
|
||||||
|
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||||
|
numPatches++
|
||||||
|
}
|
||||||
|
|
||||||
|
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
|
||||||
|
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
|
||||||
|
|
||||||
|
switch llm.KV().Architecture() {
|
||||||
|
case "mllama":
|
||||||
|
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
|
||||||
|
|
||||||
|
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
|
||||||
|
|
||||||
|
graphSize = 4 * (8 +
|
||||||
|
imageSize*imageSize*numChannels*maxNumTiles +
|
||||||
|
embeddingLength*numPatches*maxNumTiles +
|
||||||
|
9*embeddingLength*numPaddedPatches*maxNumTiles +
|
||||||
|
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
|
||||||
|
case "gemma3":
|
||||||
|
graphSize = 4 * (imageSize*imageSize*numChannels +
|
||||||
|
embeddingLength*patchSize +
|
||||||
|
numPatches*numPatches*headCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return weights, graphSize
|
||||||
|
}
|
||||||
|
|
||||||
// SupportsKVCacheType checks if the requested cache type is supported
|
// SupportsKVCacheType checks if the requested cache type is supported
|
||||||
func (ggml GGML) SupportsKVCacheType(cacheType string) bool {
|
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||||
validKVCacheTypes := []string{"f16", "q8_0", "q4_0"}
|
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
||||||
return slices.Contains(validKVCacheTypes, cacheType)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SupportsFlashAttention checks if the model supports flash attention
|
// SupportsFlashAttention checks if the model supports flash attention
|
||||||
func (ggml GGML) SupportsFlashAttention() bool {
|
func (f GGML) SupportsFlashAttention() bool {
|
||||||
_, isEmbedding := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]
|
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
|
||||||
if isEmbedding {
|
if isEmbedding {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check head counts match and are non-zero
|
// Check head counts match and are non-zero
|
||||||
headCountK := ggml.KV().EmbeddingHeadCountK()
|
headCountK := f.KV().EmbeddingHeadCountK()
|
||||||
headCountV := ggml.KV().EmbeddingHeadCountV()
|
headCountV := f.KV().EmbeddingHeadCountV()
|
||||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||||
}
|
}
|
||||||
|
|
||||||
fs/ggml/ggml_test.go (new file, 212 lines)
@@ -0,0 +1,212 @@
|
|||||||
|
package ggml
|
||||||
|
|
||||||
|
import (
|
||||||
|
"maps"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTensorLayers(t *testing.T) {
|
||||||
|
tensors := make(map[string]*Tensor)
|
||||||
|
for _, name := range []string{
|
||||||
|
"token_embd.weight",
|
||||||
|
"blk.0.attn_k.weight",
|
||||||
|
"blk.0.attn_output.weight",
|
||||||
|
"blk.0.attn_q.weight",
|
||||||
|
"blk.0.attn_v.weight",
|
||||||
|
"blk.0.attn_norm.weight",
|
||||||
|
"blk.0.ffn_down.weight",
|
||||||
|
"blk.0.ffn_gate.weight",
|
||||||
|
"blk.0.ffn_up.weight",
|
||||||
|
"blk.0.ffn_norm.weight",
|
||||||
|
"output_norm.weight",
|
||||||
|
"mm.0.bias",
|
||||||
|
"mm.0.weight",
|
||||||
|
"v.blk.0.attn_k.weight",
|
||||||
|
"v.blk.0.attn_output.weight",
|
||||||
|
"v.blk.0.attn_q.weight",
|
||||||
|
"v.blk.0.attn_v.weight",
|
||||||
|
"v.blk.0.attn_norm.weight",
|
||||||
|
"v.blk.0.ffn_down.weight",
|
||||||
|
"v.blk.0.ffn_gate.weight",
|
||||||
|
"v.blk.0.ffn_up.weight",
|
||||||
|
"v.blk.0.ffn_norm.weight",
|
||||||
|
"v.patch_embd.weight",
|
||||||
|
"v.position_embd.gate",
|
||||||
|
"v.position_embd.weight",
|
||||||
|
} {
|
||||||
|
tensors[name] = &Tensor{Name: name}
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
items []*Tensor
|
||||||
|
want map[string]Layer
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "text",
|
||||||
|
items: slices.Collect(func(yield func(*Tensor) bool) {
|
||||||
|
for k, v := range tensors {
|
||||||
|
if !strings.HasPrefix(k, "mm.") && !strings.HasPrefix(k, "v.") {
|
||||||
|
if !yield(v) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
want: map[string]Layer{
|
||||||
|
"blk.0": {
|
||||||
|
"attn_k.weight": tensors["blk.0.attn_k.weight"],
|
||||||
|
"attn_q.weight": tensors["blk.0.attn_q.weight"],
|
||||||
|
"attn_v.weight": tensors["blk.0.attn_v.weight"],
|
||||||
|
"attn_output.weight": tensors["blk.0.attn_output.weight"],
|
||||||
|
"attn_norm.weight": tensors["blk.0.attn_norm.weight"],
|
||||||
|
"ffn_down.weight": tensors["blk.0.ffn_down.weight"],
|
||||||
|
"ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
|
||||||
|
"ffn_up.weight": tensors["blk.0.ffn_up.weight"],
|
||||||
|
"ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
|
||||||
|
},
|
||||||
|
"token_embd": {"weight": tensors["token_embd.weight"]},
|
||||||
|
"output_norm": {"weight": tensors["output_norm.weight"]},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "vision",
|
||||||
|
items: slices.Collect(func(yield func(*Tensor) bool) {
|
||||||
|
for k, v := range tensors {
|
||||||
|
if strings.HasPrefix(k, "mm.") || strings.HasPrefix(k, "v.") {
|
||||||
|
if !yield(v) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
want: map[string]Layer{
|
||||||
|
"mm.0": {
|
||||||
|
"bias": tensors["mm.0.bias"],
|
||||||
|
"weight": tensors["mm.0.weight"],
|
||||||
|
},
|
||||||
|
"v.blk.0": {
|
||||||
|
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
|
||||||
|
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
|
||||||
|
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
|
||||||
|
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
|
||||||
|
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
|
||||||
|
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
|
||||||
|
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
|
||||||
|
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
|
||||||
|
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
|
||||||
|
},
|
||||||
|
"v": {
|
||||||
|
"patch_embd.weight": tensors["v.patch_embd.weight"],
|
||||||
|
"position_embd.gate": tensors["v.position_embd.gate"],
|
||||||
|
"position_embd.weight": tensors["v.position_embd.weight"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "vision and text",
|
||||||
|
items: slices.Collect(maps.Values(tensors)),
|
||||||
|
want: map[string]Layer{
|
||||||
|
"blk.0": {
|
||||||
|
"attn_k.weight": tensors["blk.0.attn_k.weight"],
|
||||||
|
"attn_q.weight": tensors["blk.0.attn_q.weight"],
|
||||||
|
"attn_v.weight": tensors["blk.0.attn_v.weight"],
|
||||||
|
"attn_output.weight": tensors["blk.0.attn_output.weight"],
|
||||||
|
"attn_norm.weight": tensors["blk.0.attn_norm.weight"],
|
||||||
|
"ffn_down.weight": tensors["blk.0.ffn_down.weight"],
|
||||||
|
"ffn_gate.weight": tensors["blk.0.ffn_gate.weight"],
|
||||||
|
"ffn_up.weight": tensors["blk.0.ffn_up.weight"],
|
||||||
|
"ffn_norm.weight": tensors["blk.0.ffn_norm.weight"],
|
||||||
|
},
|
||||||
|
"token_embd": {"weight": tensors["token_embd.weight"]},
|
||||||
|
"output_norm": {"weight": tensors["output_norm.weight"]},
|
||||||
|
"mm.0": {
|
||||||
|
"bias": tensors["mm.0.bias"],
|
||||||
|
"weight": tensors["mm.0.weight"],
|
||||||
|
},
|
||||||
|
"v.blk.0": {
|
||||||
|
"attn_k.weight": tensors["v.blk.0.attn_k.weight"],
|
||||||
|
"attn_q.weight": tensors["v.blk.0.attn_q.weight"],
|
||||||
|
"attn_v.weight": tensors["v.blk.0.attn_v.weight"],
|
||||||
|
"attn_output.weight": tensors["v.blk.0.attn_output.weight"],
|
||||||
|
"attn_norm.weight": tensors["v.blk.0.attn_norm.weight"],
|
||||||
|
"ffn_down.weight": tensors["v.blk.0.ffn_down.weight"],
|
||||||
|
"ffn_gate.weight": tensors["v.blk.0.ffn_gate.weight"],
|
||||||
|
"ffn_up.weight": tensors["v.blk.0.ffn_up.weight"],
|
||||||
|
"ffn_norm.weight": tensors["v.blk.0.ffn_norm.weight"],
|
||||||
|
},
|
||||||
|
"v": {
|
||||||
|
"patch_embd.weight": tensors["v.patch_embd.weight"],
|
||||||
|
"position_embd.gate": tensors["v.position_embd.gate"],
|
||||||
|
"position_embd.weight": tensors["v.position_embd.weight"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := Tensors{items: tt.items}.GroupLayers()
|
||||||
|
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
|
t.Errorf("unexpected layers (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
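TestTensorLayers above pins down how Tensors.GroupLayers splits flat GGUF tensor names into per-layer maps. As a rough illustration of the naming convention the test relies on (a hypothetical helper assuming the standard strings package, not the package's actual implementation):

// layerOf sketches the grouping rule exercised above:
// "blk.0.attn_k.weight"  -> ("blk.0", "attn_k.weight")
// "v.patch_embd.weight"  -> ("v", "patch_embd.weight")
// "token_embd.weight"    -> ("token_embd", "weight")
func layerOf(name string) (layer, key string) {
	parts := strings.Split(name, ".")
	switch {
	case len(parts) > 2 && (parts[0] == "blk" || parts[0] == "mm"):
		// numbered text/projector blocks keep their index in the layer name
		return strings.Join(parts[:2], "."), strings.Join(parts[2:], ".")
	case len(parts) > 3 && parts[0] == "v" && parts[1] == "blk":
		// vision blocks: v.blk.N
		return strings.Join(parts[:3], "."), strings.Join(parts[3:], ".")
	default:
		return parts[0], strings.Join(parts[1:], ".")
	}
}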
|
|
||||||
|
// ref: https://github.com/ggml-org/llama.cpp/blob/a82c9e7c23ef6db48cebfa194dc9cebbc4ac3552/ggml/src/ggml.c#L572
|
||||||
|
func TestTensorTypes(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
kind uint32
|
||||||
|
blockSize uint64
|
||||||
|
typeSize uint64
|
||||||
|
}{
|
||||||
|
{0, 1, 4},
|
||||||
|
{1, 1, 2},
|
||||||
|
{2, 32, 18},
|
||||||
|
{3, 32, 20},
|
||||||
|
{6, 32, 22},
|
||||||
|
{7, 32, 24},
|
||||||
|
{8, 32, 34},
|
||||||
|
{9, 32, 36},
|
||||||
|
{10, 256, 84},
|
||||||
|
{11, 256, 110},
|
||||||
|
{12, 256, 144},
|
||||||
|
{13, 256, 176},
|
||||||
|
{14, 256, 210},
|
||||||
|
{15, 256, 292},
|
||||||
|
{16, 256, 66},
|
||||||
|
{17, 256, 74},
|
||||||
|
{18, 256, 98},
|
||||||
|
{19, 256, 50},
|
||||||
|
{20, 32, 18},
|
||||||
|
{21, 256, 110},
|
||||||
|
{22, 256, 82},
|
||||||
|
{23, 256, 136},
|
||||||
|
{24, 1, 1},
|
||||||
|
{25, 1, 2},
|
||||||
|
{26, 1, 4},
|
||||||
|
{27, 1, 8},
|
||||||
|
{28, 1, 8},
|
||||||
|
{29, 256, 56},
|
||||||
|
{30, 1, 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(strconv.Itoa(int(tt.kind)), func(t *testing.T) {
|
||||||
|
tensor := Tensor{Kind: tt.kind}
|
||||||
|
if tensor.blockSize() != tt.blockSize {
|
||||||
|
t.Errorf("unexpected block size: got=%d want=%d", tensor.blockSize(), tt.blockSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tensor.typeSize() != tt.typeSize {
|
||||||
|
t.Errorf("unexpected type size: got=%d want=%d", tensor.typeSize(), tt.typeSize)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
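The kind/blockSize/typeSize table above mirrors ggml's type traits: quantized tensors are stored in fixed-size blocks, so bytes = elements / blockSize * typeSize. A small sketch of that arithmetic (the helper name is illustrative, not part of this package):

// quantizedBytes shows how block size and type size determine storage.
// For example kind 2 (Q4_0) has blockSize 32 and typeSize 18, so a
// 4096-element row occupies 4096/32*18 = 2304 bytes.
func quantizedBytes(elements, blockSize, typeSize uint64) uint64 {
	blocks := elements / blockSize // assumes elements is a multiple of blockSize
	return blocks * typeSize
}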
@@ -1,4 +1,4 @@
-package llm
+package ggml
 
 import (
 	"bytes"
@@ -8,10 +8,9 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"maps"
 	"slices"
 	"strings"
-
-	"golang.org/x/exp/maps"
 )
 
 type containerGGUF struct {
@@ -110,9 +109,9 @@ func (llm *gguf) KV() KV {
 	return llm.kv
 }
 
-func (llm *gguf) Tensors() *Tensors {
-	return &Tensors{
-		Items: llm.tensors,
+func (llm *gguf) Tensors() Tensors {
+	return Tensors{
+		items: llm.tensors,
 		Offset: llm.tensorOffset,
 	}
 }
@@ -523,7 +522,7 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
 		return err
 	}
 
-	keys := maps.Keys(kv)
+	keys := slices.Collect(maps.Keys(kv))
 	slices.Sort(keys)
 
 	for _, key := range keys {
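The last hunk swaps golang.org/x/exp/maps for the standard library maps package (Go 1.23+), whose Keys returns an iterator rather than a slice, so the keys are materialized with slices.Collect before sorting. A minimal standalone sketch of the same pattern:

package main

import (
	"fmt"
	"maps"
	"slices"
)

func main() {
	kv := map[string]int{"b": 2, "a": 1, "c": 3}
	keys := slices.Collect(maps.Keys(kv)) // iterator -> []string
	slices.Sort(keys)                     // deterministic order for serialization
	fmt.Println(keys)                     // [a b c]
}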
@@ -1,4 +1,4 @@
-package llm
+package ggml
 
 import "fmt"
 
@@ -98,10 +98,10 @@ func ParseFileType(s string) (fileType, error) {
 		return fileTypeIQ3_M, nil
 	case "IQ2_S":
 		return fileTypeIQ2_S, nil
-	case "IQ4_XS":
-		return fileTypeIQ4_XS, nil
 	case "IQ2_M":
 		return fileTypeIQ2_M, nil
+	case "IQ4_XS":
+		return fileTypeIQ4_XS, nil
 	case "IQ1_M":
 		return fileTypeIQ1_M, nil
 	case "BF16":
19	go.mod

@@ -1,6 +1,6 @@
 module github.com/ollama/ollama
 
-go 1.23.4
+go 1.24.0
 
 require (
 	github.com/containerd/console v1.0.3
@@ -11,7 +11,7 @@ require (
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.9.0
 	github.com/x448/float16 v0.8.4
-	golang.org/x/sync v0.10.0
+	golang.org/x/sync v0.11.0
 )
 
 require (
@@ -24,7 +24,7 @@ require (
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	golang.org/x/image v0.22.0
-	gonum.org/v1/gonum v0.15.0
+	golang.org/x/tools v0.30.0
 )
 
 require (
@@ -44,6 +44,7 @@ require (
 	github.com/xtgo/set v1.0.0 // indirect
 	go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+	gonum.org/v1/gonum v0.15.0 // indirect
 	gorgonia.org/vecf32 v0.9.0 // indirect
 	gorgonia.org/vecf64 v0.9.0 // indirect
 )
@@ -69,12 +70,12 @@ require (
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
 	golang.org/x/arch v0.8.0 // indirect
-	golang.org/x/crypto v0.31.0
+	golang.org/x/crypto v0.33.0
-	golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
+	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
-	golang.org/x/net v0.25.0 // indirect
+	golang.org/x/net v0.35.0 // indirect
-	golang.org/x/sys v0.28.0
+	golang.org/x/sys v0.30.0
-	golang.org/x/term v0.27.0
+	golang.org/x/term v0.29.0
-	golang.org/x/text v0.21.0
+	golang.org/x/text v0.22.0
 	google.golang.org/protobuf v1.34.1
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

30	go.sum
@@ -214,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
|||||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
||||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
|
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
|
||||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
|
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
|
||||||
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
||||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||||
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
|||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
|||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
|
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
||||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
@@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
|
|||||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||||
|
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
||||||
|
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|||||||
22	grammar/bench_test.go	Normal file

@@ -0,0 +1,22 @@
//go:build go1.24

package grammar

import "testing"

func BenchmarkFromSchema(b *testing.B) {
	for tt := range testCases(b) {
		b.Run("", func(b *testing.B) {
			s := []byte(tt.schema)

			b.ReportAllocs()
			for b.Loop() {
				_, err := FromSchema(nil, s)
				if err != nil {
					b.Fatalf("GrammarFromSchema: %v", err)
				}
			}
		})
		return
	}
}
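BenchmarkFromSchema uses testing.B.Loop, added in Go 1.24, which is why the file carries a go1.24 build constraint and go.mod moves to go 1.24.0. B.Loop replaces the classic b.N loop while keeping setup outside the timed region; roughly, the pre-1.24 form of the same body would look like this (a sketch, the helper name is illustrative):

// Pre-Go 1.24 form of the benchmark body above.
func benchmarkFromSchemaOld(b *testing.B, schema []byte) {
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := FromSchema(nil, schema); err != nil {
			b.Fatalf("FromSchema: %v", err)
		}
	}
}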
227	grammar/grammar.go	Normal file

@@ -0,0 +1,227 @@
|
|||||||
|
package grammar
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/grammar/jsonschema"
|
||||||
|
)
|
||||||
|
|
||||||
|
const jsonTerms = `
|
||||||
|
# Unicode
|
||||||
|
#
|
||||||
|
# Unicode characters can be specified directly in the grammar, for example
|
||||||
|
# hiragana ::= [ぁ-ゟ], or with escapes: 8-bit (\xXX), 16-bit (\uXXXX) or 32-bit
|
||||||
|
# (\UXXXXXXXX).
|
||||||
|
unicode ::= \x{hex}{2} | \u{hex}{4} | \U{hex}{8}
|
||||||
|
|
||||||
|
# JSON grammar from RFC 7159
|
||||||
|
null ::= "null"
|
||||||
|
object ::= "{" (kv ("," kv)*)? "}"
|
||||||
|
array ::= "[" (value ("," value)*)? "]"
|
||||||
|
kv ::= string ":" value
|
||||||
|
integer ::= "0" | [1-9] [0-9]*
|
||||||
|
number ::= "-"? integer frac? exp?
|
||||||
|
frac ::= "." [0-9]+
|
||||||
|
exp ::= ("e" | "E") ("+" | "-") [0-9]+
|
||||||
|
string ::= "\"" char* "\""
|
||||||
|
escape ::= ["/" | "b" | "f" | "n" | "r" | "t" | unicode]
|
||||||
|
char ::= [^"\\] | escape
|
||||||
|
space ::= (" " | "\t" | "\n" | "\r")*
|
||||||
|
hex ::= [0-9] | [a-f] | [A-F]
|
||||||
|
boolean ::= "true" | "false"
|
||||||
|
value ::= object | array | string | number | boolean | "null"
|
||||||
|
|
||||||
|
# User-defined
|
||||||
|
`
|
||||||
|
|
||||||
|
// FromSchema generates a grammar from a JSON schema.
|
||||||
|
func FromSchema(buf []byte, jsonSchema []byte) ([]byte, error) {
|
||||||
|
var s *jsonschema.Schema
|
||||||
|
if err := json.Unmarshal(jsonSchema, &s); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var g builder
|
||||||
|
|
||||||
|
// "root" is the only rule that is guaranteed to exist, so we start
|
||||||
|
// with its length for padding, and then adjust it as we go.
|
||||||
|
g.pad = len("root")
|
||||||
|
for id := range dependencies("root", s) {
|
||||||
|
g.pad = max(g.pad, len(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
g.b.WriteString(jsonTerms)
|
||||||
|
|
||||||
|
ids := make(map[*jsonschema.Schema]string)
|
||||||
|
for id, s := range dependencies("root", s) {
|
||||||
|
ids[s] = id
|
||||||
|
g.define(id)
|
||||||
|
if err := fromSchema(&g, ids, s); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
g.define("root")
|
||||||
|
if err := fromSchema(&g, ids, s); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
g.define("") // finalize the last rule
|
||||||
|
return g.b.Bytes(), nil
|
||||||
|
}
|
||||||
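FromSchema writes the fixed JSON terminals first, then one rule per dependency, and finally the root rule. A usage sketch under assumed imports (log, os, and the github.com/ollama/ollama/grammar package); the schema literal and variable names are illustrative only:

schema := []byte(`{
	"properties": {
		"name": {"type": "string"},
		"age":  {"type": "number"}
	}
}`)

g, err := grammar.FromSchema(nil, schema)
if err != nil {
	log.Fatal(err)
}
os.Stdout.Write(g) // jsonTerms followed by the root_0, root_1 and root rules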
|
|
||||||
|
func fromSchema(g *builder, ids map[*jsonschema.Schema]string, s *jsonschema.Schema) error {
|
||||||
|
switch typ := s.EffectiveType(); typ {
|
||||||
|
case "array":
|
||||||
|
if len(s.PrefixItems) == 0 && s.Items == nil {
|
||||||
|
g.u("array")
|
||||||
|
} else {
|
||||||
|
g.q("[")
|
||||||
|
for i, s := range s.PrefixItems {
|
||||||
|
if i > 0 {
|
||||||
|
g.q(",")
|
||||||
|
}
|
||||||
|
g.u(ids[s])
|
||||||
|
}
|
||||||
|
if s.Items != nil {
|
||||||
|
g.u("(")
|
||||||
|
if len(s.PrefixItems) > 0 {
|
||||||
|
g.q(",")
|
||||||
|
}
|
||||||
|
g.u(ids[s.Items])
|
||||||
|
g.u(")*")
|
||||||
|
}
|
||||||
|
g.q("]")
|
||||||
|
}
|
||||||
|
case "object":
|
||||||
|
if len(s.Properties) == 0 {
|
||||||
|
g.u("object")
|
||||||
|
} else {
|
||||||
|
g.q("{")
|
||||||
|
for i, p := range s.Properties {
|
||||||
|
name := ids[p]
|
||||||
|
if i > 0 {
|
||||||
|
g.q(",")
|
||||||
|
}
|
||||||
|
g.q(p.Name)
|
||||||
|
g.q(":")
|
||||||
|
g.u(name)
|
||||||
|
}
|
||||||
|
g.q("}")
|
||||||
|
}
|
||||||
|
case "number":
|
||||||
|
buildConstrainedNumber(g, s)
|
||||||
|
case "string":
|
||||||
|
if len(s.Enum) == 0 {
|
||||||
|
g.u("string")
|
||||||
|
} else {
|
||||||
|
g.u("(")
|
||||||
|
for i, e := range s.Enum {
|
||||||
|
if i > 0 {
|
||||||
|
g.q("|")
|
||||||
|
}
|
||||||
|
g.q(string(e))
|
||||||
|
}
|
||||||
|
g.u(")")
|
||||||
|
}
|
||||||
|
case "boolean", "value", "null", "integer":
|
||||||
|
g.u(typ)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%s: unsupported type %q", s.Name, typ)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dependencies returns a sequence of all child dependencies of the schema in
|
||||||
|
// post-order.
|
||||||
|
//
|
||||||
|
// The first value is the id/pointer to the dependency, and the second value
|
||||||
|
// is the schema.
|
||||||
|
func dependencies(id string, s *jsonschema.Schema) iter.Seq2[string, *jsonschema.Schema] {
|
||||||
|
return func(yield func(string, *jsonschema.Schema) bool) {
|
||||||
|
for i, p := range s.Properties {
|
||||||
|
id := fmt.Sprintf("%s_%d", id, i)
|
||||||
|
for did, d := range dependencies(id, p) {
|
||||||
|
if !yield(did, d) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !yield(id, p) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, p := range s.PrefixItems {
|
||||||
|
id := fmt.Sprintf("tuple_%d", i)
|
||||||
|
for did, d := range dependencies(id, p) {
|
||||||
|
id := fmt.Sprintf("%s_%s", id, did)
|
||||||
|
if !yield(id, d) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !yield(id, p) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.Items != nil {
|
||||||
|
id := fmt.Sprintf("%s_tuple_%d", id, len(s.PrefixItems))
|
||||||
|
for did, d := range dependencies(id, s.Items) {
|
||||||
|
if !yield(did, d) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !yield(id, s.Items) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
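dependencies returns an iter.Seq2 over the schema's children in post-order, so callers can consume it directly with a range-over-func loop (Go 1.23+). A small sketch, assuming schemaJSON is caller-provided bytes:

var s *jsonschema.Schema
if err := json.Unmarshal(schemaJSON, &s); err != nil {
	return err
}
for id, child := range dependencies("root", s) {
	// children are yielded before their parents
	fmt.Printf("%s -> %s\n", id, child.EffectiveType())
}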
|
|
||||||
|
type builder struct {
|
||||||
|
b bytes.Buffer
|
||||||
|
pad int
|
||||||
|
rules int
|
||||||
|
items int
|
||||||
|
}
|
||||||
|
|
||||||
|
// define terminates the current rule, if any, and then either starts a new
|
||||||
|
// rule or does nothing else if the name is empty.
|
||||||
|
func (b *builder) define(name string) {
|
||||||
|
if b.rules > 0 {
|
||||||
|
b.b.WriteString(";\n")
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&b.b, "% -*s", b.pad, name)
|
||||||
|
b.b.WriteString(" ::=")
|
||||||
|
b.rules++
|
||||||
|
b.items = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// quote appends a terminal to the current rule.
|
||||||
|
func (b *builder) q(s string) {
|
||||||
|
if b.items > 0 {
|
||||||
|
b.b.WriteString(" ")
|
||||||
|
}
|
||||||
|
b.b.WriteString(" ")
|
||||||
|
b.b.WriteString(strconv.Quote(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// u appends a non-terminal to the current rule.
|
||||||
|
func (b *builder) u(s string) {
|
||||||
|
if b.items > 0 {
|
||||||
|
b.b.WriteString(" ")
|
||||||
|
}
|
||||||
|
b.b.WriteString(" ")
|
||||||
|
b.b.WriteString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildConstrainedNumber(b *builder, s *jsonschema.Schema) {
|
||||||
|
if s.Minimum == 0 && s.Maximum == 0 {
|
||||||
|
b.u("TODO")
|
||||||
|
} else {
|
||||||
|
b.u("number")
|
||||||
|
}
|
||||||
|
}
|
||||||
75	grammar/grammar_test.go	Normal file

@@ -0,0 +1,75 @@
|
|||||||
|
package grammar
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"cmp"
|
||||||
|
"iter"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
_ "embed"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/grammar/internal/diff"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFromSchema(t *testing.T) {
|
||||||
|
for tt := range testCases(t) {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
g, err := FromSchema(nil, []byte(tt.schema))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FromSchema: %v", err)
|
||||||
|
}
|
||||||
|
got := string(g)
|
||||||
|
got = strings.TrimPrefix(got, jsonTerms)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Logf("schema:\n%s", tt.schema)
|
||||||
|
t.Fatal(string(diff.Diff("got", []byte(got), "want", []byte(tt.want))))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
schema string
|
||||||
|
want string
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:embed testdata/schemas.txt
|
||||||
|
var tests string
|
||||||
|
|
||||||
|
func testCases(t testing.TB) iter.Seq[testCase] {
|
||||||
|
t.Helper()
|
||||||
|
return func(yield func(testCase) bool) {
|
||||||
|
t.Helper()
|
||||||
|
sc := bufio.NewScanner(strings.NewReader(tests))
|
||||||
|
name := ""
|
||||||
|
for sc.Scan() {
|
||||||
|
line := strings.TrimSpace(sc.Text())
|
||||||
|
if line == "" {
|
||||||
|
name = ""
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if line[0] == '#' {
|
||||||
|
name = cmp.Or(name, strings.TrimSpace(line[1:]))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s := sc.Text()
|
||||||
|
g := ""
|
||||||
|
for sc.Scan() {
|
||||||
|
line = strings.TrimSpace(sc.Text())
|
||||||
|
if line == "" || line[0] == '#' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
g += sc.Text() + "\n"
|
||||||
|
}
|
||||||
|
if !yield(testCase{name, s, g}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
name = strings.TrimSpace(strings.TrimPrefix(line, "#"))
|
||||||
|
}
|
||||||
|
if err := sc.Err(); err != nil {
|
||||||
|
t.Fatalf("error reading tests: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
261	grammar/internal/diff/diff.go	Normal file

@@ -0,0 +1,261 @@
|
|||||||
|
// Copyright 2022 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package diff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A pair is a pair of values tracked for both the x and y side of a diff.
|
||||||
|
// It is typically a pair of line indexes.
|
||||||
|
type pair struct{ x, y int }
|
||||||
|
|
||||||
|
// Diff returns an anchored diff of the two texts old and new
|
||||||
|
// in the “unified diff” format. If old and new are identical,
|
||||||
|
// Diff returns a nil slice (no output).
|
||||||
|
//
|
||||||
|
// Unix diff implementations typically look for a diff with
|
||||||
|
// the smallest number of lines inserted and removed,
|
||||||
|
// which can in the worst case take time quadratic in the
|
||||||
|
// number of lines in the texts. As a result, many implementations
|
||||||
|
// either can be made to run for a long time or cut off the search
|
||||||
|
// after a predetermined amount of work.
|
||||||
|
//
|
||||||
|
// In contrast, this implementation looks for a diff with the
|
||||||
|
// smallest number of “unique” lines inserted and removed,
|
||||||
|
// where unique means a line that appears just once in both old and new.
|
||||||
|
// We call this an “anchored diff” because the unique lines anchor
|
||||||
|
// the chosen matching regions. An anchored diff is usually clearer
|
||||||
|
// than a standard diff, because the algorithm does not try to
|
||||||
|
// reuse unrelated blank lines or closing braces.
|
||||||
|
// The algorithm also guarantees to run in O(n log n) time
|
||||||
|
// instead of the standard O(n²) time.
|
||||||
|
//
|
||||||
|
// Some systems call this approach a “patience diff,” named for
|
||||||
|
// the “patience sorting” algorithm, itself named for a solitaire card game.
|
||||||
|
// We avoid that name for two reasons. First, the name has been used
|
||||||
|
// for a few different variants of the algorithm, so it is imprecise.
|
||||||
|
// Second, the name is frequently interpreted as meaning that you have
|
||||||
|
// to wait longer (to be patient) for the diff, meaning that it is a slower algorithm,
|
||||||
|
// when in fact the algorithm is faster than the standard one.
|
||||||
|
func Diff(oldName string, old []byte, newName string, new []byte) []byte {
|
||||||
|
if bytes.Equal(old, new) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
x := lines(old)
|
||||||
|
y := lines(new)
|
||||||
|
|
||||||
|
// Print diff header.
|
||||||
|
var out bytes.Buffer
|
||||||
|
fmt.Fprintf(&out, "diff %s %s\n", oldName, newName)
|
||||||
|
fmt.Fprintf(&out, "--- %s\n", oldName)
|
||||||
|
fmt.Fprintf(&out, "+++ %s\n", newName)
|
||||||
|
|
||||||
|
// Loop over matches to consider,
|
||||||
|
// expanding each match to include surrounding lines,
|
||||||
|
// and then printing diff chunks.
|
||||||
|
// To avoid setup/teardown cases outside the loop,
|
||||||
|
// tgs returns a leading {0,0} and trailing {len(x), len(y)} pair
|
||||||
|
// in the sequence of matches.
|
||||||
|
var (
|
||||||
|
done pair // printed up to x[:done.x] and y[:done.y]
|
||||||
|
chunk pair // start lines of current chunk
|
||||||
|
count pair // number of lines from each side in current chunk
|
||||||
|
ctext []string // lines for current chunk
|
||||||
|
)
|
||||||
|
for _, m := range tgs(x, y) {
|
||||||
|
if m.x < done.x {
|
||||||
|
// Already handled scanning forward from earlier match.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand matching lines as far as possible,
|
||||||
|
// establishing that x[start.x:end.x] == y[start.y:end.y].
|
||||||
|
// Note that on the first (or last) iteration we may (or definitely do)
|
||||||
|
// have an empty match: start.x==end.x and start.y==end.y.
|
||||||
|
start := m
|
||||||
|
for start.x > done.x && start.y > done.y && x[start.x-1] == y[start.y-1] {
|
||||||
|
start.x--
|
||||||
|
start.y--
|
||||||
|
}
|
||||||
|
end := m
|
||||||
|
for end.x < len(x) && end.y < len(y) && x[end.x] == y[end.y] {
|
||||||
|
end.x++
|
||||||
|
end.y++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit the mismatched lines before start into this chunk.
|
||||||
|
// (No effect on first sentinel iteration, when start = {0,0}.)
|
||||||
|
for _, s := range x[done.x:start.x] {
|
||||||
|
ctext = append(ctext, "-"+s)
|
||||||
|
count.x++
|
||||||
|
}
|
||||||
|
for _, s := range y[done.y:start.y] {
|
||||||
|
ctext = append(ctext, "+"+s)
|
||||||
|
count.y++
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we're not at EOF and have too few common lines,
|
||||||
|
// the chunk includes all the common lines and continues.
|
||||||
|
const C = 3 // number of context lines
|
||||||
|
if (end.x < len(x) || end.y < len(y)) &&
|
||||||
|
(end.x-start.x < C || (len(ctext) > 0 && end.x-start.x < 2*C)) {
|
||||||
|
for _, s := range x[start.x:end.x] {
|
||||||
|
ctext = append(ctext, " "+s)
|
||||||
|
count.x++
|
||||||
|
count.y++
|
||||||
|
}
|
||||||
|
done = end
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// End chunk with common lines for context.
|
||||||
|
if len(ctext) > 0 {
|
||||||
|
n := end.x - start.x
|
||||||
|
if n > C {
|
||||||
|
n = C
|
||||||
|
}
|
||||||
|
for _, s := range x[start.x : start.x+n] {
|
||||||
|
ctext = append(ctext, " "+s)
|
||||||
|
count.x++
|
||||||
|
count.y++
|
||||||
|
}
|
||||||
|
done = pair{start.x + n, start.y + n}
|
||||||
|
|
||||||
|
// Format and emit chunk.
|
||||||
|
// Convert line numbers to 1-indexed.
|
||||||
|
// Special case: empty file shows up as 0,0 not 1,0.
|
||||||
|
if count.x > 0 {
|
||||||
|
chunk.x++
|
||||||
|
}
|
||||||
|
if count.y > 0 {
|
||||||
|
chunk.y++
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&out, "@@ -%d,%d +%d,%d @@\n", chunk.x, count.x, chunk.y, count.y)
|
||||||
|
for _, s := range ctext {
|
||||||
|
out.WriteString(s)
|
||||||
|
}
|
||||||
|
count.x = 0
|
||||||
|
count.y = 0
|
||||||
|
ctext = ctext[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we reached EOF, we're done.
|
||||||
|
if end.x >= len(x) && end.y >= len(y) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise start a new chunk.
|
||||||
|
chunk = pair{end.x - C, end.y - C}
|
||||||
|
for _, s := range x[chunk.x:end.x] {
|
||||||
|
ctext = append(ctext, " "+s)
|
||||||
|
count.x++
|
||||||
|
count.y++
|
||||||
|
}
|
||||||
|
done = end
|
||||||
|
}
|
||||||
|
|
||||||
|
return out.Bytes()
|
||||||
|
}
|
||||||
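Diff emits a unified-style patch with up to three context lines per chunk and returns nil for identical inputs. Since the package sits under grammar/internal/diff it is only importable from inside this module; a usage sketch from a sibling package:

old := []byte("a\nb\nc\n")
new := []byte("a\nB\nc\n")
patch := diff.Diff("old", old, "new", new)
if patch != nil {
	fmt.Printf("%s", patch) // header lines, then an @@ -1,3 +1,3 @@ chunk
}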
|
|
||||||
|
// lines returns the lines in the file x, including newlines.
|
||||||
|
// If the file does not end in a newline, one is supplied
|
||||||
|
// along with a warning about the missing newline.
|
||||||
|
func lines(x []byte) []string {
|
||||||
|
l := strings.SplitAfter(string(x), "\n")
|
||||||
|
if l[len(l)-1] == "" {
|
||||||
|
l = l[:len(l)-1]
|
||||||
|
} else {
|
||||||
|
// Treat last line as having a message about the missing newline attached,
|
||||||
|
// using the same text as BSD/GNU diff (including the leading backslash).
|
||||||
|
l[len(l)-1] += "\n\\ No newline at end of file\n"
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// tgs returns the pairs of indexes of the longest common subsequence
|
||||||
|
// of unique lines in x and y, where a unique line is one that appears
|
||||||
|
// once in x and once in y.
|
||||||
|
//
|
||||||
|
// The longest common subsequence algorithm is as described in
|
||||||
|
// Thomas G. Szymanski, “A Special Case of the Maximal Common
|
||||||
|
// Subsequence Problem,” Princeton TR #170 (January 1975),
|
||||||
|
// available at https://research.swtch.com/tgs170.pdf.
|
||||||
|
func tgs(x, y []string) []pair {
|
||||||
|
// Count the number of times each string appears in a and b.
|
||||||
|
// We only care about 0, 1, many, counted as 0, -1, -2
|
||||||
|
// for the x side and 0, -4, -8 for the y side.
|
||||||
|
// Using negative numbers now lets us distinguish positive line numbers later.
|
||||||
|
m := make(map[string]int)
|
||||||
|
for _, s := range x {
|
||||||
|
if c := m[s]; c > -2 {
|
||||||
|
m[s] = c - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, s := range y {
|
||||||
|
if c := m[s]; c > -8 {
|
||||||
|
m[s] = c - 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now unique strings can be identified by m[s] = -1+-4.
|
||||||
|
//
|
||||||
|
// Gather the indexes of those strings in x and y, building:
|
||||||
|
// xi[i] = increasing indexes of unique strings in x.
|
||||||
|
// yi[i] = increasing indexes of unique strings in y.
|
||||||
|
// inv[i] = index j such that x[xi[i]] = y[yi[j]].
|
||||||
|
var xi, yi, inv []int
|
||||||
|
for i, s := range y {
|
||||||
|
if m[s] == -1+-4 {
|
||||||
|
m[s] = len(yi)
|
||||||
|
yi = append(yi, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, s := range x {
|
||||||
|
if j, ok := m[s]; ok && j >= 0 {
|
||||||
|
xi = append(xi, i)
|
||||||
|
inv = append(inv, j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply Algorithm A from Szymanski's paper.
|
||||||
|
// In those terms, A = J = inv and B = [0, n).
|
||||||
|
// We add sentinel pairs {0,0}, and {len(x),len(y)}
|
||||||
|
// to the returned sequence, to help the processing loop.
|
||||||
|
J := inv
|
||||||
|
n := len(xi)
|
||||||
|
T := make([]int, n)
|
||||||
|
L := make([]int, n)
|
||||||
|
for i := range T {
|
||||||
|
T[i] = n + 1
|
||||||
|
}
|
||||||
|
for i := range n {
|
||||||
|
k := sort.Search(n, func(k int) bool {
|
||||||
|
return T[k] >= J[i]
|
||||||
|
})
|
||||||
|
T[k] = J[i]
|
||||||
|
L[i] = k + 1
|
||||||
|
}
|
||||||
|
k := 0
|
||||||
|
for _, v := range L {
|
||||||
|
if k < v {
|
||||||
|
k = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
seq := make([]pair, 2+k)
|
||||||
|
seq[1+k] = pair{len(x), len(y)} // sentinel at end
|
||||||
|
lastj := n
|
||||||
|
for i := n - 1; i >= 0; i-- {
|
||||||
|
if L[i] == k && J[i] < lastj {
|
||||||
|
seq[k] = pair{xi[i], yi[J[i]]}
|
||||||
|
k--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
seq[0] = pair{0, 0} // sentinel at start
|
||||||
|
return seq
|
||||||
|
}
|
||||||
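tgs reduces the anchored diff to a longest common subsequence over lines that are unique on both sides, and the T/L loop with sort.Search above is the patience-style longest increasing subsequence computed by binary search. A self-contained sketch of that core step (illustrative only, assuming the standard sort package; not part of this package's API):

// lisLength returns the length of the longest strictly increasing
// subsequence of seq, using the same tails-array plus binary-search
// idea as tgs above.
func lisLength(seq []int) int {
	var tails []int // tails[k] = smallest tail of an increasing subsequence of length k+1
	for _, v := range seq {
		k := sort.Search(len(tails), func(i int) bool { return tails[i] >= v })
		if k == len(tails) {
			tails = append(tails, v)
		} else {
			tails[k] = v
		}
	}
	return len(tails)
}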
44	grammar/internal/diff/diff_test.go	Normal file

@@ -0,0 +1,44 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package diff

import (
	"bytes"
	"path/filepath"
	"testing"

	"golang.org/x/tools/txtar"
)

func clean(text []byte) []byte {
	text = bytes.ReplaceAll(text, []byte("$\n"), []byte("\n"))
	text = bytes.TrimSuffix(text, []byte("^D\n"))
	return text
}

func Test(t *testing.T) {
	files, _ := filepath.Glob("testdata/*.txt")
	if len(files) == 0 {
		t.Fatalf("no testdata")
	}

	for _, file := range files {
		t.Run(filepath.Base(file), func(t *testing.T) {
			a, err := txtar.ParseFile(file)
			if err != nil {
				t.Fatal(err)
			}
			if len(a.Files) != 3 || a.Files[2].Name != "diff" {
				t.Fatalf("%s: want three files, third named \"diff\"", file)
			}
			diffs := Diff(a.Files[0].Name, clean(a.Files[0].Data), a.Files[1].Name, clean(a.Files[1].Data))
			want := clean(a.Files[2].Data)
			if !bytes.Equal(diffs, want) {
				t.Fatalf("%s: have:\n%s\nwant:\n%s\n%s", file,
					diffs, want, Diff("have", diffs, "want", want))
			}
		})
	}
}
13	grammar/internal/diff/testdata/allnew.txt	vendored	Normal file

@@ -0,0 +1,13 @@
-- old --
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -0,0 +1,3 @@
+a
+b
+c
13	grammar/internal/diff/testdata/allold.txt	vendored	Normal file

@@ -0,0 +1,13 @@
-- old --
a
b
c
-- new --
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +0,0 @@
-a
-b
-c
35	grammar/internal/diff/testdata/basic.txt	vendored	Normal file

@@ -0,0 +1,35 @@
|
|||||||
|
Example from Hunt and McIlroy, “An Algorithm for Differential File Comparison.”
|
||||||
|
https://www.cs.dartmouth.edu/~doug/diff.pdf
|
||||||
|
|
||||||
|
-- old --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
d
|
||||||
|
e
|
||||||
|
f
|
||||||
|
g
|
||||||
|
-- new --
|
||||||
|
w
|
||||||
|
a
|
||||||
|
b
|
||||||
|
x
|
||||||
|
y
|
||||||
|
z
|
||||||
|
e
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,7 +1,7 @@
|
||||||
|
+w
|
||||||
|
a
|
||||||
|
b
|
||||||
|
-c
|
||||||
|
-d
|
||||||
|
+x
|
||||||
|
+y
|
||||||
|
+z
|
||||||
|
e
|
||||||
|
-f
|
||||||
|
-g
|
||||||
40	grammar/internal/diff/testdata/dups.txt	vendored	Normal file

@@ -0,0 +1,40 @@
|
|||||||
|
-- old --
|
||||||
|
a
|
||||||
|
|
||||||
|
b
|
||||||
|
|
||||||
|
c
|
||||||
|
|
||||||
|
d
|
||||||
|
|
||||||
|
e
|
||||||
|
|
||||||
|
f
|
||||||
|
-- new --
|
||||||
|
a
|
||||||
|
|
||||||
|
B
|
||||||
|
|
||||||
|
C
|
||||||
|
|
||||||
|
d
|
||||||
|
|
||||||
|
e
|
||||||
|
|
||||||
|
f
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,8 +1,8 @@
|
||||||
|
a
|
||||||
|
$
|
||||||
|
-b
|
||||||
|
-
|
||||||
|
-c
|
||||||
|
+B
|
||||||
|
+
|
||||||
|
+C
|
||||||
|
$
|
||||||
|
d
|
||||||
|
$
|
||||||
38	grammar/internal/diff/testdata/end.txt	vendored	Normal file

@@ -0,0 +1,38 @@
|
|||||||
|
-- old --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
eight
|
||||||
|
nine
|
||||||
|
ten
|
||||||
|
eleven
|
||||||
|
-- new --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -5,7 +5,6 @@
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
-eight
|
||||||
|
-nine
|
||||||
|
-ten
|
||||||
|
-eleven
|
||||||
|
+8
|
||||||
|
+9
|
||||||
|
+10
|
||||||
9	grammar/internal/diff/testdata/eof.txt	vendored	Normal file

@@ -0,0 +1,9 @@
-- old --
a
b
c^D
-- new --
a
b
c^D
-- diff --
18	grammar/internal/diff/testdata/eof1.txt	vendored	Normal file

@@ -0,0 +1,18 @@
-- old --
a
b
c
-- new --
a
b
c^D
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
 a
 b
-c
+c
\ No newline at end of file
18	grammar/internal/diff/testdata/eof2.txt	vendored	Normal file

@@ -0,0 +1,18 @@
-- old --
a
b
c^D
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
 a
 b
-c
\ No newline at end of file
+c
62	grammar/internal/diff/testdata/long.txt	vendored	Normal file

@@ -0,0 +1,62 @@
|
|||||||
|
-- old --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
11
|
||||||
|
12
|
||||||
|
13
|
||||||
|
14
|
||||||
|
14½
|
||||||
|
15
|
||||||
|
16
|
||||||
|
17
|
||||||
|
18
|
||||||
|
19
|
||||||
|
20
|
||||||
|
-- new --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
11
|
||||||
|
12
|
||||||
|
13
|
||||||
|
14
|
||||||
|
17
|
||||||
|
18
|
||||||
|
19
|
||||||
|
20
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -4,7 +4,6 @@
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
-7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
@@ -12,9 +11,6 @@
|
||||||
|
12
|
||||||
|
13
|
||||||
|
14
|
||||||
|
-14½
|
||||||
|
-15
|
||||||
|
-16
|
||||||
|
17
|
||||||
|
18
|
||||||
|
19
|
||||||
5	grammar/internal/diff/testdata/same.txt	vendored	Normal file

@@ -0,0 +1,5 @@
-- old --
hello world
-- new --
hello world
-- diff --
34	grammar/internal/diff/testdata/start.txt	vendored	Normal file

@@ -0,0 +1,34 @@
|
|||||||
|
-- old --
|
||||||
|
e
|
||||||
|
pi
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
-- new --
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
10
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,5 +1,6 @@
|
||||||
|
-e
|
||||||
|
-pi
|
||||||
|
+1
|
||||||
|
+2
|
||||||
|
+3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
40	grammar/internal/diff/testdata/triv.txt	vendored	Normal file

@@ -0,0 +1,40 @@
|
|||||||
|
Another example from Hunt and McIlroy,
|
||||||
|
“An Algorithm for Differential File Comparison.”
|
||||||
|
https://www.cs.dartmouth.edu/~doug/diff.pdf
|
||||||
|
|
||||||
|
Anchored diff gives up on finding anything,
|
||||||
|
since there are no unique lines.
|
||||||
|
|
||||||
|
-- old --
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
a
|
||||||
|
b
|
||||||
|
b
|
||||||
|
a
|
||||||
|
-- new --
|
||||||
|
c
|
||||||
|
a
|
||||||
|
b
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
-- diff --
|
||||||
|
diff old new
|
||||||
|
--- old
|
||||||
|
+++ new
|
||||||
|
@@ -1,7 +1,6 @@
|
||||||
|
-a
|
||||||
|
-b
|
||||||
|
-c
|
||||||
|
-a
|
||||||
|
-b
|
||||||
|
-b
|
||||||
|
-a
|
||||||
|
+c
|
||||||
|
+a
|
||||||
|
+b
|
||||||
|
+a
|
||||||
|
+b
|
||||||
|
+c
|
||||||
171	grammar/jsonschema/decode.go	Normal file

@@ -0,0 +1,171 @@
|
|||||||
|
package jsonschema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Schema holds a JSON schema.
|
||||||
|
type Schema struct {
|
||||||
|
// Name is the name of the property. For the parent/root property, this
|
||||||
|
// is "root". For child properties, this is the name of the property.
|
||||||
|
Name string `json:"-"`
|
||||||
|
|
||||||
|
// Type is the type of the property.
|
||||||
|
//
|
||||||
|
// TODO: Union types (e.g. make this a []string).
|
||||||
|
Type string
|
||||||
|
|
||||||
|
// PrefixItems is a list of schemas for each item in a tuple. By
|
||||||
|
// default, the tuple is "closed" unless Items is set to true or a
|
||||||
|
// valid Schema.
|
||||||
|
PrefixItems []*Schema
|
||||||
|
|
||||||
|
// Items is the schema for each item in a list.
|
||||||
|
//
|
||||||
|
// If it is missing, or its JSON value is "null" or "false", it is nil.
|
||||||
|
// If the JSON value is "true", it is set to the empty Schema. If the
|
||||||
|
// JSON value is an object, it will be decoded as a Schema.
|
||||||
|
Items *Schema
|
||||||
|
|
||||||
|
// MinItems specifies the minimum number of items allowed in a list.
|
||||||
|
MinItems int
|
||||||
|
|
||||||
|
// MaxItems specifies the maximum number of items allowed in a list.
|
||||||
|
MaxItems int
|
||||||
|
|
||||||
|
// Properties is the schema for each property of an object.
|
||||||
|
Properties []*Schema
|
||||||
|
|
||||||
|
// Format is the format of the property. This is used to validate the
|
||||||
|
// property against a specific format.
|
||||||
|
//
|
||||||
|
// It is the caller's responsibility to validate the property against
|
||||||
|
// the format.
|
||||||
|
Format string
|
||||||
|
|
||||||
|
// Minimum specifies the minimum value for numeric properties.
|
||||||
|
Minimum float64
|
||||||
|
|
||||||
|
// Maximum specifies the maximum value for numeric properties.
|
||||||
|
Maximum float64
|
||||||
|
|
||||||
|
// Enum is a list of valid values for the property.
|
||||||
|
Enum []json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Schema) UnmarshalJSON(data []byte) error {
|
||||||
|
type S Schema
|
||||||
|
w := struct {
|
||||||
|
Properties props
|
||||||
|
Items items
|
||||||
|
*S
|
||||||
|
}{
|
||||||
|
S: (*S)(s),
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &w); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if w.Items.set {
|
||||||
|
s.Items = &w.Items.Schema
|
||||||
|
}
|
||||||
|
s.Properties = w.Properties
|
||||||
|
return nil
|
||||||
|
}
|
||||||
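UnmarshalJSON above relies on the common local-alias trick: `type S Schema` has no methods, so decoding into (*S)(s) does not re-enter this UnmarshalJSON, while the wrapper struct intercepts only the Properties and Items fields. A minimal standalone sketch of the pattern (types and defaults are illustrative, not from this package):

type Config struct {
	Name string
	Port int
}

func (c *Config) UnmarshalJSON(data []byte) error {
	type alias Config // alias drops the method set, so this does not recurse
	var a alias
	if err := json.Unmarshal(data, &a); err != nil {
		return err
	}
	if a.Port == 0 {
		a.Port = 8080 // apply a default after decoding
	}
	*c = Config(a)
	return nil
}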
|
|
||||||
|
type items struct {
|
||||||
|
Schema
|
||||||
|
set bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *items) UnmarshalJSON(data []byte) error {
|
||||||
|
switch b := data[0]; b {
|
||||||
|
case 't':
|
||||||
|
*s = items{set: true}
|
||||||
|
case '{':
|
||||||
|
type I items
|
||||||
|
if err := json.Unmarshal(data, (*I)(s)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.set = true
|
||||||
|
case 'n', 'f':
|
||||||
|
default:
|
||||||
|
return errors.New("invalid Items")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EffectiveType returns the effective type of the schema. If the Type field is
|
||||||
|
// not empty, it is returned; otherwise:
|
||||||
|
//
|
||||||
|
// - If the schema has both Properties and Items, it returns an empty string.
|
||||||
|
// - If the schema has Properties, it returns "object".
|
||||||
|
// - If the schema has Items, it returns "array".
|
||||||
|
// - If the schema has neither Properties nor Items, it returns "value".
|
||||||
|
//
|
||||||
|
// The returned string is never empty.
|
||||||
|
func (d *Schema) EffectiveType() string {
|
||||||
|
if d.Type == "" {
|
||||||
|
if len(d.Properties) > 0 {
|
||||||
|
return "object"
|
||||||
|
}
|
||||||
|
if len(d.PrefixItems) > 0 || d.Items != nil {
|
||||||
|
return "array"
|
||||||
|
}
|
||||||
|
return "value"
|
||||||
|
}
|
||||||
|
return d.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
// props is an ordered list of properties. The order of the properties
|
||||||
|
// is the order in which they were defined in the schema.
|
||||||
|
type props []*Schema
|
||||||
|
|
||||||
|
var _ json.Unmarshaler = (*props)(nil)
|
||||||
|
|
||||||
|
func (v *props) UnmarshalJSON(data []byte) error {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data[0] != '{' {
|
||||||
|
return errors.New("expected object")
|
||||||
|
}
|
||||||
|
|
||||||
|
d := json.NewDecoder(bytes.NewReader(data))
|
||||||
|
|
||||||
|
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
|
||||||
|
// llama.cpp, ignore unknown fields, which could lead to unexpected
|
||||||
|
// behavior for clients of this package, since they may not be aware
|
||||||
|
// that "additionalFields", "itemsPrefix", etc, are being ignored.
|
||||||
|
//
|
||||||
|
// For now, just do what llama.cpp does.
|
||||||
|
|
||||||
|
t, err := d.Token()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if t != json.Delim('{') {
|
||||||
|
return errors.New("expected object")
|
||||||
|
}
|
||||||
|
for d.More() {
|
||||||
|
// Use the first token (map key) as the property name, then
|
||||||
|
// decode the rest of the object fields into a Schema and
|
||||||
|
// append.
|
||||||
|
t, err := d.Token()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if t == json.Delim('}') {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s := &Schema{
|
||||||
|
Name: t.(string),
|
||||||
|
}
|
||||||
|
if err := d.Decode(s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*v = append(*v, s)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
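Because props decodes the "properties" object with a streaming json.Decoder, Schema.Properties is a slice that preserves the declaration order of the schema, unlike a Go map. A sketch of relying on that ordering (the schema literal is illustrative; assumes encoding/json, fmt, and log):

var s *jsonschema.Schema
raw := []byte(`{"properties": {"first": {"type": "string"}, "second": {"type": "number"}}}`)
if err := json.Unmarshal(raw, &s); err != nil {
	log.Fatal(err)
}
for _, p := range s.Properties {
	fmt.Println(p.Name, p.EffectiveType()) // "first string", then "second number"
}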
104	grammar/jsonschema/decode_test.go	Normal file

@@ -0,0 +1,104 @@
|
|||||||
|
package jsonschema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testSchemaBasic = `
|
||||||
|
{
|
||||||
|
"properties": {
|
||||||
|
"tupleClosedEmpty": { "prefixItems": [] },
|
||||||
|
"tupleClosedMissing": { "prefixItems": [{}] },
|
||||||
|
"tupleClosedNull": { "prefixItems": [{}], "items": null },
|
||||||
|
"tupleClosedFalse": { "prefixItems": [{}], "items": false },
|
||||||
|
"tupleOpenTrue": { "prefixItems": [{}], "items": true },
|
||||||
|
"tupleOpenEmpty": { "prefixItems": [{}], "items": {} },
|
||||||
|
"tupleOpenTyped": { "prefixItems": [{}], "items": {"type": "boolean"} },
|
||||||
|
"tupleOpenMax": { "prefixItems": [{}], "items": true, "maxItems": 3},
|
||||||
|
|
||||||
|
"array": { "items": {"type": "number"} },
|
||||||
|
|
||||||
|
"null": { "type": "null" },
|
||||||
|
"string": { "type": "string" },
|
||||||
|
"boolean": { "type": "boolean" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestSchemaUnmarshal(t *testing.T) {
|
||||||
|
var got *Schema
|
||||||
|
if err := json.Unmarshal([]byte(testSchemaBasic), &got); err != nil {
|
||||||
|
t.Fatalf("Unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
want := &Schema{
|
||||||
|
Properties: []*Schema{
|
||||||
|
{Name: "tupleClosedEmpty", PrefixItems: []*Schema{}, Items: nil},
|
||||||
|
{Name: "tupleClosedMissing", PrefixItems: []*Schema{{}}, Items: nil},
|
||||||
|
{Name: "tupleClosedNull", PrefixItems: []*Schema{{}}, Items: nil},
|
||||||
|
{Name: "tupleClosedFalse", PrefixItems: []*Schema{{}}, Items: nil},
|
||||||
|
|
||||||
|
{Name: "tupleOpenTrue", PrefixItems: []*Schema{{}}, Items: &Schema{}},
|
||||||
|
{Name: "tupleOpenEmpty", PrefixItems: []*Schema{{}}, Items: &Schema{}},
|
||||||
|
{Name: "tupleOpenTyped", PrefixItems: []*Schema{{}}, Items: &Schema{Type: "boolean"}},
|
||||||
|
{Name: "tupleOpenMax", PrefixItems: []*Schema{{}}, Items: &Schema{}, MaxItems: 3},
|
||||||
|
|
||||||
|
{Name: "array", Items: &Schema{Type: "number"}},
|
||||||
|
|
||||||
|
{Name: "null", Type: "null"},
|
||||||
|
{Name: "string", Type: "string"},
|
||||||
|
{Name: "boolean", Type: "boolean"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
|
t.Errorf("(-want, +got)\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEffectiveType(t *testing.T) {
|
||||||
|
const schema = `
|
||||||
|
{"properties": {
|
||||||
|
"o": {"type": "object"},
|
||||||
|
"a": {"type": "array"},
|
||||||
|
"n": {"type": "number"},
|
||||||
|
"s": {"type": "string"},
|
||||||
|
"z": {"type": "null"},
|
||||||
|
"b": {"type": "boolean"},
|
||||||
|
|
||||||
|
"t0": {"prefixItems": [{}], "items": {"type": "number"}},
|
||||||
|
"t1": {"items": {"type": "number"}, "maxItems": 3},
|
||||||
|
|
||||||
|
"v": {"maxItems": 3}
|
||||||
|
}}
|
||||||
|
`
|
||||||
|
|
||||||
|
var s *Schema
|
||||||
|
if err := json.Unmarshal([]byte(schema), &s); err != nil {
|
||||||
|
t.Fatalf("json.Unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []string
|
||||||
|
for _, p := range s.Properties {
|
||||||
|
got = append(got, p.EffectiveType())
|
||||||
|
}
|
||||||
|
|
||||||
|
want := strings.Fields(`
|
||||||
|
object
|
||||||
|
array
|
||||||
|
number
|
||||||
|
string
|
||||||
|
null
|
||||||
|
boolean
|
||||||
|
array
|
||||||
|
array
|
||||||
|
value
|
||||||
|
`)
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Errorf("\ngot:\n\t%v\nwant:\n\t%v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
76
grammar/testdata/schemas.txt
vendored
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# This file holds tests for JSON schema to EBNF grammar conversions.
|
||||||
|
#
|
||||||
|
# The format is a JSON schema, followed by the expected EBNF grammar. Each test
|
||||||
|
# MAY be preceded by a comment that describes the test (e.g. the test name), followed by
|
||||||
|
# the JSON schema and the expected EBNF grammar. If no comment is present, the test
|
||||||
|
# name is the test's number in the file (e.g. "#0", "#1", etc.)
|
||||||
|
#
|
||||||
|
# A blank line marks the end of one test and the start of the next. Comments can be added
|
||||||
|
# anywhere in the file, but they must be preceded by a '#' character and start at
|
||||||
|
# the beginning of the line.
|
||||||
|
|
||||||
|
# default
|
||||||
|
{}
|
||||||
|
root ::= value;
|
||||||
|
|
||||||
|
{"properties": {}}
|
||||||
|
root ::= value;
|
||||||
|
|
||||||
|
# array
|
||||||
|
{"properties": {"a": {"type": "array", "items": {"type": "string"}}}}
|
||||||
|
root_0_tuple_0 ::= string;
|
||||||
|
root_0 ::= "[" ( root_0_tuple_0 )* "]";
|
||||||
|
root ::= "{" "a" ":" root_0 "}";
|
||||||
|
|
||||||
|
# array with nested array
|
||||||
|
{"type": "array", "items": {"type": "array", "items": {"type": "string"}}}
|
||||||
|
root_tuple_0_tuple_0 ::= string;
|
||||||
|
root_tuple_0 ::= "[" ( root_tuple_0_tuple_0 )* "]";
|
||||||
|
root ::= "[" ( root_tuple_0 )* "]";
|
||||||
|
|
||||||
|
# object
|
||||||
|
{"properties": {"e": {}}}
|
||||||
|
root_0 ::= value;
|
||||||
|
root ::= "{" "e" ":" root_0 "}";
|
||||||
|
|
||||||
|
# object with nested object
|
||||||
|
{"properties": {"o": {"type": "object", "properties": {"e": {}}}}}
|
||||||
|
root_0_0 ::= value;
|
||||||
|
root_0 ::= "{" "e" ":" root_0_0 "}";
|
||||||
|
root ::= "{" "o" ":" root_0 "}";
|
||||||
|
|
||||||
|
# boolean
|
||||||
|
{"type": "boolean"}
|
||||||
|
root ::= boolean;
|
||||||
|
|
||||||
|
# number
|
||||||
|
{"properties": {"n": {"type": "number", "minimum": 123, "maximum": 4567}}}
|
||||||
|
root_0 ::= number;
|
||||||
|
root ::= "{" "n" ":" root_0 "}";
|
||||||
|
|
||||||
|
# string
|
||||||
|
{"type": "string"}
|
||||||
|
root ::= string;
|
||||||
|
|
||||||
|
# string with enum
|
||||||
|
{"type": "string", "enum": ["a", "b", "c"]}
|
||||||
|
root ::= ( "\"a\"" "|" "\"b\"" "|" "\"c\"" );
|
||||||
|
|
||||||
|
# spaces in key
|
||||||
|
{"properties": {"a b": {}}}
|
||||||
|
root_0 ::= value;
|
||||||
|
root ::= "{" "a b" ":" root_0 "}";
|
||||||
|
|
||||||
|
# issue7978
|
||||||
|
{ "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "object", "properties": { "explanation": { "type": "string" }, "output": { "type": "string" } }, "required": [ "explanation", "output" ], "additionalProperties": false } }, "final_answer": { "type": "string" } }, "required": [ "steps", "final_answer" ], "additionalProperties": false }
|
||||||
|
root_0_tuple_0_0 ::= string;
|
||||||
|
root_0_tuple_0_1 ::= string;
|
||||||
|
root_0_tuple_0 ::= "{" "explanation" ":" root_0_tuple_0_0 "," "output" ":" root_0_tuple_0_1 "}";
|
||||||
|
root_0 ::= "[" ( root_0_tuple_0 )* "]";
|
||||||
|
root_1 ::= string;
|
||||||
|
root ::= "{" "steps" ":" root_0 "," "final_answer" ":" root_1 "}";
|
||||||
|
|
||||||
|
# !! # special characters in key
|
||||||
|
# !! {"properties": {"a!b": {}}}
|
||||||
|
# !! !invalid character '!' in key
|
||||||
|
# !!
|
||||||
@@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) {
|
|||||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
func TestIntegrationSplitBatch(t *testing.T) {
|
||||||
|
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req := api.GenerateRequest{
|
||||||
|
Model: "gemma3:4b",
|
||||||
|
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||||
|
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
|
||||||
|
Prompt: "what does the text in this image say?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
Images: []api.ImageData{
|
||||||
|
image,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: sometimes it returns "the ollamas", sometimes "the ollams"
|
||||||
|
resp := "the ollam"
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||||
|
// llava models on CPU can be quite slow to start,
|
||||||
|
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
|
||||||
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
|
|
||||||
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
|
|
||||||
|
|||||||
71
kvcache/cache.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrKvCacheFull = errors.New("could not find a kv cache slot")
|
||||||
|
ErrNotSupported = errors.New("model does not support operation")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Cache interface {
|
||||||
|
// ** used by model implementations **
|
||||||
|
|
||||||
|
// SetLayer sets the active layer of the cache
|
||||||
|
SetLayer(layer int)
|
||||||
|
|
||||||
|
// Get returns the history of key and value tensors plus a mask
|
||||||
|
//
|
||||||
|
// The shape of the tensors is documented in the specific
|
||||||
|
// cache implementation used.
|
||||||
|
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
|
||||||
|
|
||||||
|
// Put stores a batch of key and value in the cache
|
||||||
|
//
|
||||||
|
// The shape of the tensors is documented in the specific
|
||||||
|
// cache implementation used.
|
||||||
|
Put(ctx ml.Context, key, value ml.Tensor)
|
||||||
|
|
||||||
|
// SetConfig controls optimizations (mostly backend-specific) that may transform
|
||||||
|
// the output of the cache to work better with specific kernels. If not called,
|
||||||
|
// the backend settings will be used. This works well when calling Attention.
|
||||||
|
//
|
||||||
|
// The config can be overridden by models, especially if they require vanilla
|
||||||
|
// output when implementing their own version of attention. To do this, pass
|
||||||
|
// an empty ml.CacheConfig.
|
||||||
|
//
|
||||||
|
// Most models will not need to use this.
|
||||||
|
SetConfig(ml.CacheConfig)
|
||||||
|
|
||||||
|
// ** cache management **
|
||||||
|
|
||||||
|
// Init sets up runtime parameters.
|
||||||
|
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||||
|
// dtype: The data type for storing cache entries
|
||||||
|
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||||
|
// capacity: The number of cache entries to store, per sequence
|
||||||
|
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||||
|
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||||
|
|
||||||
|
// Close closes the cache and frees resources associated with it
|
||||||
|
Close()
|
||||||
|
|
||||||
|
// StartForward is called before the start of the model's forward pass.
|
||||||
|
// For each token in the coming batch, there must be a corresponding
|
||||||
|
// entry in positions and seqs.
|
||||||
|
StartForward(ctx ml.Context, batch input.Batch) error
|
||||||
|
|
||||||
|
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||||
|
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||||
|
|
||||||
|
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||||
|
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||||
|
//
|
||||||
|
// If an error occurs, the entire context for the sequence should be
|
||||||
|
// removed by calling Remove(seq, 0, math.MaxInt32)
|
||||||
|
Remove(seq int, beginIndex, endIndex int32) error
|
||||||
|
}
|
||||||
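// Hedged sketch of how a model's attention layer might drive this interface.
// The helper name and the idea of computing attention from the returned
// tensors are assumptions; only the Cache calls mirror the interface above.
//
//	func cachedKV(ctx ml.Context, cache Cache, layer int, k, v ml.Tensor) (ml.Tensor, ml.Tensor, ml.Tensor) {
//		cache.SetLayer(layer) // select this layer's K/V storage
//		cache.Put(ctx, k, v)  // append the current batch's keys and values
//		return cache.Get(ctx) // full history plus the attention mask
//	}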
683
kvcache/causal.go
Normal file
@@ -0,0 +1,683 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||||
|
|
||||||
|
// Causal cache stores K and V tensors according to their position in the
|
||||||
|
// sequence. Returns the history and a mask for attending to past tokens
|
||||||
|
//
|
||||||
|
// The tensors are of shape embed dim, kv heads, batch size
|
||||||
|
// The mask is of shape history size, batch size
|
||||||
|
type Causal struct {
|
||||||
|
DType ml.DType
|
||||||
|
windowSize int32
|
||||||
|
|
||||||
|
opts CausalOptions
|
||||||
|
|
||||||
|
// config controls mostly backend-specific optimizations
|
||||||
|
config *ml.CacheConfig
|
||||||
|
|
||||||
|
// ** current forward pass **
|
||||||
|
|
||||||
|
// the active layer for Get and Put
|
||||||
|
curLayer int
|
||||||
|
|
||||||
|
// starting location for data storage for this batch
|
||||||
|
curLoc int
|
||||||
|
|
||||||
|
// size of the current batch
|
||||||
|
curBatchSize int
|
||||||
|
|
||||||
|
// mask of the cache as used by this batch
|
||||||
|
curMask ml.Tensor
|
||||||
|
|
||||||
|
// locations in the cache that are needed for this batch
|
||||||
|
curCellRange cellRange
|
||||||
|
|
||||||
|
// curSequences is the sequences corresponding to this pass's entries in the cache
|
||||||
|
curSequences []int
|
||||||
|
|
||||||
|
// curPositions is the positions corresponding to this pass's entries in the cache
|
||||||
|
curPositions []int32
|
||||||
|
|
||||||
|
// ** cache metadata **
|
||||||
|
|
||||||
|
// for each possible location in the cache, stores the position and set of sequences
|
||||||
|
// that reference the data there
|
||||||
|
cells []cacheCell
|
||||||
|
|
||||||
|
// maps from sequence to the range of locations where it is stored in the cache
|
||||||
|
cellRanges map[int]cellRange
|
||||||
|
|
||||||
|
// ** cache data storage **
|
||||||
|
|
||||||
|
shiftFn shiftFn
|
||||||
|
backend ml.Backend
|
||||||
|
ctxs map[int]ml.Context
|
||||||
|
keys, values map[int]ml.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheCell struct {
|
||||||
|
pos int32
|
||||||
|
sequences []int
|
||||||
|
}
|
||||||
|
|
||||||
|
type cellRange struct {
|
||||||
|
min int
|
||||||
|
max int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCausalCache(shift shiftFn) *Causal {
|
||||||
|
return &Causal{
|
||||||
|
windowSize: math.MaxInt32,
|
||||||
|
shiftFn: shift,
|
||||||
|
ctxs: make(map[int]ml.Context),
|
||||||
|
keys: make(map[int]ml.Tensor),
|
||||||
|
values: make(map[int]ml.Tensor),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||||
|
return &Causal{
|
||||||
|
windowSize: windowSize,
|
||||||
|
shiftFn: shift,
|
||||||
|
ctxs: make(map[int]ml.Context),
|
||||||
|
keys: make(map[int]ml.Tensor),
|
||||||
|
values: make(map[int]ml.Tensor),
|
||||||
|
}
|
||||||
|
}
|
||||||
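// Illustrative construction (assumed model-side code): a full causal cache and
// a sliding-window cache sharing a shift callback. The callback body is a
// placeholder; real models would re-rotate the keys by the position delta.
//
//	shift := func(ctx ml.Context, layer int, key, delta ml.Tensor) (ml.Tensor, error) {
//		return key, nil // placeholder, no rotation applied
//	}
//	full := NewCausalCache(shift)  // windowSize defaults to math.MaxInt32
//	swa := NewSWACache(512, shift) // keeps roughly the last 512 positions per sequence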
|
|
||||||
|
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
|
if c.config == nil {
|
||||||
|
var config ml.CacheConfig
|
||||||
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
|
config = cc.CacheConfig()
|
||||||
|
}
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.CachePadding == 0 {
|
||||||
|
c.config.CachePadding = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.MaskBatchPadding == 0 {
|
||||||
|
c.config.MaskBatchPadding = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.MaskDType == ml.DTypeOther {
|
||||||
|
c.config.MaskDType = ml.DTypeF32
|
||||||
|
}
|
||||||
|
|
||||||
|
var cacheSize int
|
||||||
|
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
|
||||||
|
cacheSize = maxSequences * capacity
|
||||||
|
} else {
|
||||||
|
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
|
||||||
|
}
|
||||||
|
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||||
|
c.cells = make([]cacheCell, cacheSize)
|
||||||
|
|
||||||
|
c.DType = dtype
|
||||||
|
c.cellRanges = make(map[int]cellRange)
|
||||||
|
c.backend = backend
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||||
|
if c.config != nil {
|
||||||
|
panic("config cannot be changed after being previously set, either by the model or backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) Close() {
|
||||||
|
for _, ctx := range c.ctxs {
|
||||||
|
ctx.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||||
|
c.curBatchSize = len(batch.Positions)
|
||||||
|
c.curSequences = batch.Sequences
|
||||||
|
c.curPositions = batch.Positions
|
||||||
|
c.opts.Except = nil
|
||||||
|
|
||||||
|
c.updateSlidingWindow()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
c.curLoc, err = c.findStartLoc()
|
||||||
|
if errors.Is(err, ErrKvCacheFull) {
|
||||||
|
c.defrag()
|
||||||
|
c.curLoc, err = c.findStartLoc()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.curCellRange = newRange()
|
||||||
|
for i, pos := range batch.Positions {
|
||||||
|
seq := batch.Sequences[i]
|
||||||
|
|
||||||
|
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||||
|
|
||||||
|
seqRange, ok := c.cellRanges[seq]
|
||||||
|
if !ok {
|
||||||
|
seqRange = newRange()
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.curLoc+i > seqRange.max {
|
||||||
|
seqRange.max = c.curLoc + i
|
||||||
|
}
|
||||||
|
if seqRange.max > c.curCellRange.max {
|
||||||
|
c.curCellRange.max = seqRange.max
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.curLoc+i < seqRange.min {
|
||||||
|
seqRange.min = c.curLoc + i
|
||||||
|
}
|
||||||
|
if seqRange.min < c.curCellRange.min {
|
||||||
|
c.curCellRange.min = seqRange.min
|
||||||
|
}
|
||||||
|
c.cellRanges[seq] = seqRange
|
||||||
|
}
|
||||||
|
|
||||||
|
c.curMask, err = c.buildMask(ctx)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRange() cellRange {
|
||||||
|
return cellRange{
|
||||||
|
min: math.MaxInt,
|
||||||
|
max: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first contiguous block of at least curBatchSize empty cells
|
||||||
|
func (c *Causal) findStartLoc() (int, error) {
|
||||||
|
var start, count int
|
||||||
|
for i := range c.cells {
|
||||||
|
if len(c.cells[i].sequences) == 0 {
|
||||||
|
count++
|
||||||
|
if count >= c.curBatchSize {
|
||||||
|
return start, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
start = i + 1
|
||||||
|
count = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) updateSlidingWindow() {
|
||||||
|
if c.windowSize == math.MaxInt32 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a map of unique sequences to the lowest position in that sequence
|
||||||
|
lowestPos := make(map[int]int32)
|
||||||
|
for i := range c.curPositions {
|
||||||
|
seq := c.curSequences[i]
|
||||||
|
|
||||||
|
pos, ok := lowestPos[seq]
|
||||||
|
if !ok {
|
||||||
|
pos = c.curPositions[i]
|
||||||
|
} else if c.curPositions[i] < pos {
|
||||||
|
pos = c.curPositions[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
lowestPos[seq] = pos
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||||
|
for seq, pos := range lowestPos {
|
||||||
|
oldRange, ok := c.cellRanges[seq]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
newRange := newRange()
|
||||||
|
|
||||||
|
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||||
|
if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
if c.cells[i].pos < pos-c.windowSize {
|
||||||
|
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||||
|
} else {
|
||||||
|
newRange.min = min(newRange.min, i)
|
||||||
|
newRange.max = max(newRange.max, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[seq] = newRange
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func roundDown(length, pad int) int {
|
||||||
|
return (length / pad) * pad
|
||||||
|
}
|
||||||
|
|
||||||
|
func roundUp(length, pad int) int {
|
||||||
|
return ((length + pad - 1) / pad) * pad
|
||||||
|
}
|
||||||
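// For example, with CachePadding of 32: roundUp(37, 32) == 64 and
// roundDown(37, 32) == 32; a padding of 1 leaves lengths unchanged.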
|
|
||||||
|
// Builds a mask of history x batch indicating whether for each token in the batch the
|
||||||
|
// token in the history should apply. This is based on both the sequence and causality (the
|
||||||
|
// position of the history is not ahead of the token in the batch).
|
||||||
|
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||||
|
// Align and pad the two dimensions as required by the backend
|
||||||
|
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||||
|
|
||||||
|
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||||
|
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||||
|
|
||||||
|
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||||
|
mask := make([]float32, batchSize*length)
|
||||||
|
|
||||||
|
for i := range c.curBatchSize {
|
||||||
|
enabled := !slices.Contains(c.opts.Except, i)
|
||||||
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||||
|
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||||
|
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||||
|
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
||||||
|
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||||
|
// has already been masked out because the sequence doesn't match.
|
||||||
|
for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||||
|
mask[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.MaskDType != ml.DTypeF32 {
|
||||||
|
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||||
|
ctx.Forward(maskTensor.Copy(ctx, out))
|
||||||
|
maskTensor = out
|
||||||
|
}
|
||||||
|
|
||||||
|
return maskTensor, nil
|
||||||
|
}
|
||||||
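// Worked example (illustrative numbers): with cells holding positions 0..3 of a
// single sequence and a batch of two tokens at positions 2 and 3, the rows
// built above are, per batch token over the cached cells:
//
//	pos 2: [0, 0, 0, -Inf] // cannot attend to the future position 3
//	pos 3: [0, 0, 0, 0]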
|
|
||||||
|
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
||||||
|
for i, key := range c.keys {
|
||||||
|
if key == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
kHeadDim := key.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
rowSize := key.Stride(2)
|
||||||
|
|
||||||
|
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
|
||||||
|
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
|
||||||
|
|
||||||
|
value := c.values[i]
|
||||||
|
var vSrcView, vDstView ml.Tensor
|
||||||
|
if c.config.PermutedV {
|
||||||
|
vHeadDim := value.Dim(1)
|
||||||
|
elemSize := value.Stride(0)
|
||||||
|
|
||||||
|
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||||
|
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||||
|
} else {
|
||||||
|
vHeadDim := value.Dim(0)
|
||||||
|
rowSize := value.Stride(2)
|
||||||
|
|
||||||
|
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
|
||||||
|
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(
|
||||||
|
kSrcView.Copy(ctx, kDstView),
|
||||||
|
vSrcView.Copy(ctx, vDstView),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) defrag() {
|
||||||
|
slog.Debug("defragmenting kv cache")
|
||||||
|
|
||||||
|
// Defrag strategy:
|
||||||
|
// - Search for empty holes at the beginning of the cache,
|
||||||
|
// filling them with active data starting at the end
|
||||||
|
// - If there are contiguous elements that need to be moved,
|
||||||
|
// combine them into a single operation by holding new moves
|
||||||
|
// until we see that the next one is non-contiguous
|
||||||
|
// - Fill up the context with the maximum number of operations it
|
||||||
|
// can hold then compute that and continue with a new context
|
||||||
|
//
|
||||||
|
// We could try to optimize placement by grouping blocks from
|
||||||
|
// the same sequences together but most likely the next forward
|
||||||
|
// pass will disrupt this anyway, so the real-world benefit
|
||||||
|
// seems limited at this time.
|
||||||
|
|
||||||
|
ctx := c.backend.NewContext()
|
||||||
|
|
||||||
|
// For every move, 6 tensors are required per layer (2 views and a
|
||||||
|
// copy for each of k and v). We also need to refer to the original
|
||||||
|
// k and v cache tensors - once per layer, not per move.
|
||||||
|
layers := 0
|
||||||
|
for _, key := range c.keys {
|
||||||
|
if key == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
layers++
|
||||||
|
}
|
||||||
|
|
||||||
|
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
|
||||||
|
moves := 0
|
||||||
|
|
||||||
|
var pendingSrc, pendingDst, pendingLen int
|
||||||
|
src := len(c.cells) - 1
|
||||||
|
|
||||||
|
for dst := 0; dst < src; dst++ {
|
||||||
|
if len(c.cells[dst].sequences) == 0 {
|
||||||
|
for ; src > dst; src-- {
|
||||||
|
if len(c.cells[src].sequences) != 0 {
|
||||||
|
c.cells[dst] = c.cells[src]
|
||||||
|
c.cells[src] = cacheCell{}
|
||||||
|
|
||||||
|
if pendingLen > 0 {
|
||||||
|
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
|
||||||
|
pendingSrc = src
|
||||||
|
pendingLen++
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||||
|
moves++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingSrc = src
|
||||||
|
pendingDst = dst
|
||||||
|
pendingLen = 1
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if moves >= maxMoves {
|
||||||
|
ctx.Compute()
|
||||||
|
ctx.Close()
|
||||||
|
ctx = c.backend.NewContext()
|
||||||
|
|
||||||
|
moves = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pendingLen > 0 {
|
||||||
|
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||||
|
moves++
|
||||||
|
}
|
||||||
|
|
||||||
|
if moves > 0 {
|
||||||
|
ctx.Compute()
|
||||||
|
}
|
||||||
|
ctx.Close()
|
||||||
|
|
||||||
|
// Reset range metadata
|
||||||
|
for seq := range c.cellRanges {
|
||||||
|
seqRange := newRange()
|
||||||
|
|
||||||
|
for i, cell := range c.cells {
|
||||||
|
if slices.Contains(cell.sequences, seq) {
|
||||||
|
if i < seqRange.min {
|
||||||
|
seqRange.min = i
|
||||||
|
}
|
||||||
|
if i > seqRange.max {
|
||||||
|
seqRange.max = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[seq] = seqRange
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) SetLayer(layer int) {
|
||||||
|
c.curLayer = layer
|
||||||
|
}
|
||||||
|
|
||||||
|
type CausalOptions struct {
|
||||||
|
// Except lists the batch indices for which the causal mask is not applied
|
||||||
|
Except []int
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCausal disables causal mask generation for a particular range of indices in
|
||||||
|
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
|
||||||
|
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||||
|
if !slices.Equal(c.opts.Except, opts.Except) {
|
||||||
|
c.opts = opts
|
||||||
|
if ctx != nil {
|
||||||
|
var err error
|
||||||
|
c.curMask, err = c.buildMask(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// This error should never occur because we have previously built a mask with the same shape
|
||||||
|
panic(fmt.Errorf("SetCausal: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
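// Assumed usage sketch (not part of this change): a multimodal model could
// disable causal masking for the batch indices that hold image tokens so they
// attend bidirectionally, then restore the default before the next pass.
//
//	imageIdxs := []int{0, 1, 2} // hypothetical batch indices holding image tokens
//	cache.SetCausal(ctx, CausalOptions{Except: imageIdxs})
//	// ... layers that need bidirectional attention over the image ...
//	cache.SetCausal(ctx, CausalOptions{})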
|
|
||||||
|
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
|
key := c.keys[c.curLayer]
|
||||||
|
value := c.values[c.curLayer]
|
||||||
|
|
||||||
|
kHeadDim := key.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
rowSize := key.Stride(2)
|
||||||
|
cachedSize := c.curMask.Dim(0)
|
||||||
|
|
||||||
|
key = key.View(ctx, rowSize*c.curCellRange.min,
|
||||||
|
kHeadDim, key.Stride(1),
|
||||||
|
numKVHeads, key.Stride(2),
|
||||||
|
cachedSize,
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.config.PermutedV {
|
||||||
|
vHeadDim := value.Dim(1)
|
||||||
|
elemSize := value.Stride(0)
|
||||||
|
|
||||||
|
value = value.View(ctx, elemSize*c.curCellRange.min,
|
||||||
|
cachedSize, value.Stride(1),
|
||||||
|
vHeadDim, value.Stride(2),
|
||||||
|
numKVHeads,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
vHeadDim := value.Dim(0)
|
||||||
|
rowSize := value.Stride(2)
|
||||||
|
|
||||||
|
value = value.View(ctx, rowSize*c.curCellRange.min,
|
||||||
|
vHeadDim, value.Stride(1),
|
||||||
|
numKVHeads, value.Stride(2),
|
||||||
|
cachedSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, value, c.curMask
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
|
kHeadDim := key.Dim(0)
|
||||||
|
vHeadDim := value.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
batchSize := key.Dim(2)
|
||||||
|
|
||||||
|
if c.curBatchSize != batchSize {
|
||||||
|
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||||
|
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.keys[c.curLayer]; !ok {
|
||||||
|
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.values[c.curLayer]; !ok {
|
||||||
|
if c.config.PermutedV {
|
||||||
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
|
||||||
|
} else {
|
||||||
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rowSize := c.keys[c.curLayer].Stride(2)
|
||||||
|
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
|
||||||
|
|
||||||
|
if c.config.PermutedV {
|
||||||
|
elemSize := c.values[c.curLayer].Stride(0)
|
||||||
|
|
||||||
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||||
|
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
|
||||||
|
} else {
|
||||||
|
rowSize := c.values[c.curLayer].Stride(2)
|
||||||
|
|
||||||
|
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
|
seqRange := newRange()
|
||||||
|
|
||||||
|
for i := range c.cells {
|
||||||
|
// Remove the contents of dstSeq so that we only have the copied prefix; metadata will be reset at the end
|
||||||
|
if slices.Contains(c.cells[i].sequences, dstSeq) {
|
||||||
|
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
|
||||||
|
c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
|
||||||
|
if i < seqRange.min {
|
||||||
|
seqRange.min = i
|
||||||
|
}
|
||||||
|
if i > seqRange.max {
|
||||||
|
seqRange.max = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[dstSeq] = seqRange
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||||
|
if c.shiftFn == nil {
|
||||||
|
return ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.backend.NewContext()
|
||||||
|
defer ctx.Close()
|
||||||
|
|
||||||
|
seqRange := c.cellRanges[seq]
|
||||||
|
size := seqRange.max - seqRange.min + 1
|
||||||
|
|
||||||
|
offsets := make([]int32, size)
|
||||||
|
for i := range offsets {
|
||||||
|
cell := c.cells[seqRange.min+i]
|
||||||
|
|
||||||
|
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||||
|
offsets[i] = offset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, key := range c.keys {
|
||||||
|
if key == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
kHeadDim := key.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
rowSize := key.Stride(2)
|
||||||
|
|
||||||
|
key = key.View(ctx, rowSize*seqRange.min,
|
||||||
|
kHeadDim, key.Stride(1),
|
||||||
|
numKVHeads, key.Stride(2),
|
||||||
|
size,
|
||||||
|
)
|
||||||
|
|
||||||
|
roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(roped.Copy(ctx, key))
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Compute()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
var offset int32
|
||||||
|
if endIndex != math.MaxInt32 {
|
||||||
|
offset = beginIndex - endIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
seqRange := newRange()
|
||||||
|
|
||||||
|
for i := range c.cells {
|
||||||
|
if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
|
||||||
|
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||||
|
} else {
|
||||||
|
if c.cells[i].pos >= endIndex {
|
||||||
|
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||||
|
// TODO(jessegross): Need to be careful about data shared between sequences
|
||||||
|
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cells[i].pos += offset
|
||||||
|
}
|
||||||
|
if i < seqRange.min {
|
||||||
|
seqRange.min = i
|
||||||
|
}
|
||||||
|
if i > seqRange.max {
|
||||||
|
seqRange.max = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if seqRange == newRange() {
|
||||||
|
delete(c.cellRanges, seq)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[seq] = seqRange
|
||||||
|
|
||||||
|
if endIndex != math.MaxInt32 {
|
||||||
|
err := c.shift(seq, endIndex+offset, offset)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
543
kvcache/causal_test.go
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
in []float32
|
||||||
|
inShape []int
|
||||||
|
seqs []int
|
||||||
|
pos []int32
|
||||||
|
expected []float32
|
||||||
|
expectedShape []int
|
||||||
|
expectedMask []float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(nil)
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||||
|
inShape: []int{2, 3, 4},
|
||||||
|
seqs: []int{0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3},
|
||||||
|
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||||
|
expectedShape: []int{2, 3, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{115, 215, 125, 225, 135, 235},
|
||||||
|
inShape: []int{2, 3, 1},
|
||||||
|
seqs: []int{0},
|
||||||
|
pos: []int32{4},
|
||||||
|
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
|
||||||
|
expectedShape: []int{2, 3, 5},
|
||||||
|
expectedMask: []float32{0, 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSWA(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewSWACache(1, nil)
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 0},
|
||||||
|
pos: []int32{4, 5},
|
||||||
|
expected: []float32{5, 6, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSequences(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(nil)
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 1, 1},
|
||||||
|
pos: []int32{0, 1, 0, 1},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 1},
|
||||||
|
pos: []int32{2, 2},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||||
|
expectedShape: []int{1, 1, 6},
|
||||||
|
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemove(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
return key.Add(ctx, shift), nil
|
||||||
|
})
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 1, 1},
|
||||||
|
pos: []int32{0, 1, 0, 1},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
err := cache.Remove(0, 1, math.MaxInt32)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests = []testCase{
|
||||||
|
{
|
||||||
|
name: "RemoveEnd",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 1},
|
||||||
|
pos: []int32{1, 2},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||||
|
expectedShape: []int{1, 1, 6},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
err = cache.Remove(0, 0, 1)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests = []testCase{
|
||||||
|
{
|
||||||
|
name: "RemoveMiddle",
|
||||||
|
in: []float32{7, 8},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 0},
|
||||||
|
pos: []int32{1, 2},
|
||||||
|
expected: []float32{7, 8, 3, 4, 4},
|
||||||
|
expectedShape: []int{1, 1, 5},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefrag(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
return key.Add(ctx, shift), nil
|
||||||
|
})
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
|
inShape: []int{1, 1, 16},
|
||||||
|
seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
|
expectedShape: []int{1, 1, 16},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
err := cache.Remove(0, 2, 4)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cache.Remove(0, 13, math.MaxInt32)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests = []testCase{
|
||||||
|
{
|
||||||
|
name: "Defrag",
|
||||||
|
in: []float32{17, 18, 19},
|
||||||
|
inShape: []int{1, 1, 3},
|
||||||
|
seqs: []int{0, 0, 0},
|
||||||
|
pos: []int32{16, 17, 18},
|
||||||
|
expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
|
||||||
|
expectedShape: []int{1, 1, 16},
|
||||||
|
expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopy(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
cache.CopyPrefix(0, 1, 2)
|
||||||
|
|
||||||
|
tests = []testCase{
|
||||||
|
{
|
||||||
|
name: "Copy",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{1, 1},
|
||||||
|
pos: []int32{3, 4},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||||
|
expectedShape: []int{1, 1, 6},
|
||||||
|
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
context := backend.NewContext()
|
||||||
|
defer context.Close()
|
||||||
|
|
||||||
|
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cache.SetLayer(0)
|
||||||
|
tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
|
||||||
|
cache.Put(context, tensor, tensor)
|
||||||
|
|
||||||
|
out, _, mask := cache.Get(context)
|
||||||
|
|
||||||
|
context.Forward(out, mask).Compute(out, mask)
|
||||||
|
|
||||||
|
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||||
|
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testBackend struct{}
|
||||||
|
|
||||||
|
func (b *testBackend) Config() ml.Config {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testBackend) Get(name string) ml.Tensor {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testBackend) NewContext() ml.Context {
|
||||||
|
return &testContext{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testBackend) NewContextSize(int) ml.Context {
|
||||||
|
return &testContext{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testBackend) SystemInfo() string {
|
||||||
|
return "not implemented"
|
||||||
|
}
|
||||||
|
|
||||||
|
type testContext struct{}
|
||||||
|
|
||||||
|
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
|
total := 0
|
||||||
|
|
||||||
|
if len(shape) > 0 {
|
||||||
|
total = 1
|
||||||
|
for _, s := range shape {
|
||||||
|
total *= s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
|
return c.Empty(dtype, shape...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
	t := c.Empty(ml.DTypeF32, shape...).(*testTensor)

	copy(t.data, s)

	return t, nil
}

func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
	f := make([]float32, len(s))
	for i := range f {
		f[i] = float32(s[i])
	}

	out, _ := c.FromFloatSlice(f, shape...)
	out.(*testTensor).dtype = ml.DTypeI32

	return out, nil
}

func (c *testContext) Input() ml.Context    { return c }
func (c *testContext) Output() ml.Context   { return c }
func (c *testContext) Layer(int) ml.Context { return c }

func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

func (c *testContext) Compute(...ml.Tensor) {}

func (c *testContext) MaxGraphNodes() int {
	return 10
}

func (c *testContext) Close() {}

type testTensor struct {
	dtype       ml.DType
	elementSize int
	data        []float32
	shape       []int
}

func (t *testTensor) Dim(n int) int {
	return t.shape[n]
}

func (t *testTensor) Stride(n int) int {
	stride := t.elementSize
	for i := range n {
		stride *= t.shape[i]
	}

	return stride
}

func (t *testTensor) Shape() []int {
	return t.shape
}

func (t *testTensor) DType() ml.DType {
	return t.dtype
}

func (t *testTensor) Bytes() []byte {
	panic("not implemented")
}

func (t *testTensor) Floats() []float32 {
	out := make([]float32, len(t.data))
	copy(out, t.data)
	return out
}

func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)

	for i := range out.data {
		out.data[i] = t.data[i] + t2.(*testTensor).data[i]
	}

	return out
}

func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	offset /= t.elementSize

	var s []int

	switch len(shape) {
	case 1:
		s = []int{shape[0]}
	case 5:
		s = []int{shape[0], shape[2], shape[4]}
	default:
		panic("unsupported number of dimensions")
	}

	context := &testContext{}

	view := context.Empty(t.dtype, s...).(*testTensor)
	view.data = t.data[offset : offset+len(view.data)]

	return view
}

func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}

func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	copy(t2.(*testTensor).data, t.data)
	return nil
}
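The stubs above implement only what the sampler tests need (FromFloatSlice, Add, Floats, View); everything else panics. A minimal sketch of how they might be exercised, assuming the surrounding `_test.go` file with its imports and that `Empty` allocates the backing slice as `FromFloatSlice` implies; the test name and expected values are illustrative, not part of the change:

```go
func TestTestTensorAdd(t *testing.T) {
	ctx := &testContext{}

	// Build two 2x2 tensors from float slices (assumes Empty allocates t.data).
	a, _ := ctx.FromFloatSlice([]float32{1, 2, 3, 4}, 2, 2)
	b, _ := ctx.FromFloatSlice([]float32{10, 20, 30, 40}, 2, 2)

	// Add works element-wise on the backing float32 slices.
	got := a.Add(ctx, b).Floats()
	want := []float32{11, 22, 33, 44}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("element %d: got %v, want %v", i, got[i], want[i])
		}
	}
}
```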
143 kvcache/encoder.go Normal file
@@ -0,0 +1,143 @@
package kvcache

import (
	"fmt"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// Encoder cache stores K and V tensors that are position independent
//
// The tensors can be of any shape and will be returned as they were stored
// The mask is currently always nil
//
// Not currently safe for multiple sequences
type EncoderCache struct {
	// config controls mostly backend-specific optimizations
	config *ml.CacheConfig

	// ** current forward pass **

	// the active layer for Get and Put
	curLayer int

	// if something is stored during this pass, this
	// will be the position (but there is no guarantee
	// anything will be stored)
	curPos int32

	// ** cache metadata **

	// was something stored in the cache?
	encoderCached bool

	// position of the cached data
	encoderPos int32

	// ** cache data storage **
	backend      ml.Backend
	ctxs         map[int]ml.Context
	keys, values map[int]ml.Tensor
}

func NewEncoderCache() *EncoderCache {
	return &EncoderCache{
		ctxs:   make(map[int]ml.Context),
		keys:   make(map[int]ml.Tensor),
		values: make(map[int]ml.Tensor),
	}
}

func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	if c.config == nil {
		var config ml.CacheConfig
		if cc, ok := backend.(ml.BackendCacheConfig); ok {
			config = cc.CacheConfig()
		}
		c.config = &config
	}

	if maxSequences > 1 {
		panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
	}

	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
	}

	c.backend = backend
}

func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
	if c.config != nil {
		panic("config cannot be changed after being previously set, either by the model or backend")
	}

	c.config = &config
}

func (c *EncoderCache) Close() {
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
}

func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
	// We work with the most recent image
	if len(batch.Multimodal) > 0 {
		c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
	}

	return nil
}

func (c *EncoderCache) SetLayer(layer int) {
	c.curLayer = layer
}

func (c *EncoderCache) EncoderCached() bool {
	return c.encoderCached
}

func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.keys[c.curLayer], c.values[c.curLayer], nil
}

func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.encoderPos = c.curPos
	c.encoderCached = true

	if c.config.PermutedV {
		value = value.Permute(ctx, 1, 2, 0, 3)
	}

	if _, ok := c.ctxs[c.curLayer]; !ok {
		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
	}

	if _, ok := c.keys[c.curLayer]; !ok {
		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
	}

	if _, ok := c.values[c.curLayer]; !ok {
		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
	}

	ctx.Forward(
		key.Copy(ctx, c.keys[c.curLayer]),
		value.Copy(ctx, c.values[c.curLayer]),
	)
}

func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
	panic("encoder cache does not support multiple sequences")
}

func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
	if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
		c.encoderCached = false
	}

	return nil
}
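As the doc comment above describes, EncoderCache keeps one position-independent set of K/V tensors per layer, typically the output of a vision encoder. A hedged sketch of the intended call order follows; the consumer function, its parameters, and the dtype/capacity values are illustrative assumptions, not code from this change:

```go
package model // illustrative consumer package, not part of this change

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// runEncoderLayers sketches the call order: Init once, StartForward per pass,
// then per layer SetLayer followed by Put (first pass only) and Get.
func runEncoderLayers(backend ml.Backend, ctx ml.Context, batch input.Batch,
	layerKV func(layer int) (key, value ml.Tensor), numLayers int) error {
	cache := kvcache.NewEncoderCache()
	cache.Init(backend, ml.DTypeF32, 1, 2048, 512) // single sequence; sizes are placeholders

	if err := cache.StartForward(ctx, batch); err != nil {
		return err
	}

	for i := range numLayers {
		cache.SetLayer(i)
		if !cache.EncoderCached() {
			k, v := layerKV(i)
			cache.Put(ctx, k, v) // stored once, reused on later decode passes
		}
		k, v, _ := cache.Get(ctx) // third return value (mask) is currently always nil
		_, _ = k, v               // cross-attention would consume k and v here
	}
	return nil
}
```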
100 kvcache/wrapper.go Normal file
@@ -0,0 +1,100 @@
package kvcache

import (
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// Wrapper cache is a container for multiple types of caches,
// such as for the encoding and decoding portions of a model.
type WrapperCache struct {
	// caches we are wrapping
	caches []Cache

	// cache to be used for this layer
	curType int
}

func NewWrapperCache(caches ...Cache) *WrapperCache {
	return &WrapperCache{
		caches: caches,
	}
}

func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	for _, cache := range c.caches {
		cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
	}
}

func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
	for _, cache := range c.caches {
		cache.SetConfig(config)
	}
}

func (c *WrapperCache) Close() {
	for _, cache := range c.caches {
		cache.Close()
	}
}

func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
	for i, cache := range c.caches {
		err := cache.StartForward(ctx, batch)
		if err != nil {
			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
			for j := i - 1; j >= 0; j-- {
				for k := range batch.Positions {
					_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
				}
			}
			return err
		}
	}

	c.curType = 0
	return nil
}

func (c *WrapperCache) SetLayer(layer int) {
	for _, cache := range c.caches {
		cache.SetLayer(layer)
	}
}

func (c *WrapperCache) SetLayerType(layerType int) {
	c.curType = layerType
}

func (c *WrapperCache) UnderlyingCache() Cache {
	return c.caches[c.curType]
}

func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.caches[c.curType].Get(ctx)
}

func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.caches[c.curType].Put(ctx, key, value)
}

func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
	for _, cache := range c.caches {
		cache.CopyPrefix(srcSeq, dstSeq, len)
	}
}

func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
	// If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
	for _, cache := range c.caches {
		err := cache.Remove(seq, beginIndex, endIndex)
		if err != nil {
			return err
		}
	}

	return nil
}
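WrapperCache fans each call out to its children and routes Put and Get to whichever cache SetLayerType last selected. A hedged sketch of the intended composition for an encoder/decoder model follows; NewCausalCache and its nil shift argument are assumptions about the rest of the kvcache package, not part of this file:

```go
package model // illustrative consumer, not part of this change

import "github.com/ollama/ollama/kvcache"

// newEncoderDecoderCache pairs a position-independent encoder cache with a
// decoder-side cache; the model later selects one per layer via SetLayerType.
func newEncoderDecoderCache() *kvcache.WrapperCache {
	return kvcache.NewWrapperCache(
		kvcache.NewEncoderCache(),   // layer type 0: vision encoder K/V
		kvcache.NewCausalCache(nil), // layer type 1: causal text decoder (assumed constructor)
	)
}
```

A model would then call SetLayerType(0) before cross-attention layers and SetLayerType(1) before causal self-attention layers so that Put and Get reach the matching underlying cache.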
@@ -8,7 +8,7 @@ Ollama vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](h

 If you update the vendoring code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory.

-```
+```shell
 make -f Makefile.sync apply-patches
 ```

@@ -22,7 +22,7 @@ When updating to a newer base commit, the existing patches may not apply cleanly

 Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.

-```
+```shell
 make -f Makefile.sync apply-patches
 ```

@@ -30,7 +30,7 @@ If there are conflicts, you will see an error message. Resolve the conflicts in

 Once all patches are applied, commit the changes to the tracking repository.

-```
+```shell
 make -f Makefile.sync format-patches sync
 ```

@@ -38,13 +38,13 @@ make -f Makefile.sync format-patches sync

 When working on new fixes or features that impact vendored code, use the following model. First get a clean tracking repo with all current patches applied:

-```
+```shell
 make -f Makefile.sync clean apply-patches
 ```

 Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with

-```
+```shell
 make -f Makefile.sync format-patches
 ```
2 llama/build-info.cpp generated vendored
@@ -1,4 +1,4 @@
 int LLAMA_BUILD_NUMBER = 0;
-char const *LLAMA_COMMIT = "ba1cb19cdd0d92e012e0f6e009e0620f854b6afd";
+char const *LLAMA_COMMIT = "d7cfe1ffe0f435d0048a6058d529daf76e072d9c";
 char const *LLAMA_COMPILER = "";
 char const *LLAMA_BUILD_TARGET = "";
4 llama/build-info.cpp.in Normal file
@@ -0,0 +1,4 @@
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "@FETCH_HEAD@";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";
343
llama/llama.cpp/common/common.cpp
vendored
343
llama/llama.cpp/common/common.cpp
vendored
@@ -2,6 +2,9 @@
|
|||||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "gguf.h"
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||||
@@ -70,6 +73,22 @@
|
|||||||
#include <sys/syslimits.h>
|
#include <sys/syslimits.h>
|
||||||
#endif
|
#endif
|
||||||
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
||||||
|
|
||||||
|
//
|
||||||
|
// CURL utils
|
||||||
|
//
|
||||||
|
|
||||||
|
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
|
||||||
|
|
||||||
|
// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
|
||||||
|
struct curl_slist_ptr {
|
||||||
|
struct curl_slist * ptr = nullptr;
|
||||||
|
~curl_slist_ptr() {
|
||||||
|
if (ptr) {
|
||||||
|
curl_slist_free_all(ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
#endif // LLAMA_USE_CURL
|
#endif // LLAMA_USE_CURL
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
@@ -464,6 +483,48 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
|||||||
s = std::move(builder);
|
s = std::move(builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
|
||||||
|
std::ostringstream result;
|
||||||
|
for (size_t i = 0; i < values.size(); ++i) {
|
||||||
|
if (i > 0) {
|
||||||
|
result << separator;
|
||||||
|
}
|
||||||
|
result << values[i];
|
||||||
|
}
|
||||||
|
return result.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
|
||||||
|
std::vector<std::string> parts;
|
||||||
|
size_t start = 0;
|
||||||
|
size_t end = str.find(delimiter);
|
||||||
|
|
||||||
|
while (end != std::string::npos) {
|
||||||
|
parts.push_back(str.substr(start, end - start));
|
||||||
|
start = end + delimiter.length();
|
||||||
|
end = str.find(delimiter, start);
|
||||||
|
}
|
||||||
|
|
||||||
|
parts.push_back(str.substr(start));
|
||||||
|
|
||||||
|
return parts;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string string_repeat(const std::string & str, size_t n) {
|
||||||
|
if (n == 0) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string result;
|
||||||
|
result.reserve(str.length() * n);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < n; ++i) {
|
||||||
|
result += str;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
std::string string_from(bool value) {
|
std::string string_from(bool value) {
|
||||||
return value ? "true" : "false";
|
return value ? "true" : "false";
|
||||||
}
|
}
|
||||||
@@ -846,7 +907,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
} else if (!params.model_url.empty()) {
|
} else if (!params.model_url.empty()) {
|
||||||
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
|
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
|
||||||
} else {
|
} else {
|
||||||
model = llama_load_model_from_file(params.model.c_str(), mparams);
|
model = llama_model_load_from_file(params.model.c_str(), mparams);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model == NULL) {
|
if (model == NULL) {
|
||||||
@@ -854,26 +915,28 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
return iparams;
|
return iparams;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
|
||||||
if (params.reranking) {
|
if (params.reranking) {
|
||||||
bool ok = true;
|
bool ok = true;
|
||||||
|
|
||||||
if (llama_token_bos(model) == LLAMA_TOKEN_NULL) {
|
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
|
||||||
LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__);
|
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_token_eos(model) == LLAMA_TOKEN_NULL) {
|
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||||
LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__);
|
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_token_sep(model) == LLAMA_TOKEN_NULL) {
|
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
|
||||||
LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__);
|
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
llama_free_model(model);
|
llama_model_free(model);
|
||||||
|
|
||||||
return iparams;
|
return iparams;
|
||||||
}
|
}
|
||||||
@@ -881,10 +944,10 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
|
|
||||||
auto cparams = common_context_params_to_llama(params);
|
auto cparams = common_context_params_to_llama(params);
|
||||||
|
|
||||||
llama_context * lctx = llama_new_context_with_model(model, cparams);
|
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||||
if (lctx == NULL) {
|
if (lctx == NULL) {
|
||||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
|
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
|
||||||
llama_free_model(model);
|
llama_model_free(model);
|
||||||
return iparams;
|
return iparams;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -895,17 +958,18 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
|
|
||||||
if (!params.control_vectors.empty()) {
|
if (!params.control_vectors.empty()) {
|
||||||
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
|
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
|
||||||
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
|
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model);
|
||||||
|
|
||||||
const auto cvec = common_control_vector_load(params.control_vectors);
|
const auto cvec = common_control_vector_load(params.control_vectors);
|
||||||
if (cvec.n_embd == -1) {
|
if (cvec.n_embd == -1) {
|
||||||
llama_free(lctx);
|
llama_free(lctx);
|
||||||
llama_free_model(model);
|
llama_model_free(model);
|
||||||
|
|
||||||
return iparams;
|
return iparams;
|
||||||
}
|
}
|
||||||
|
|
||||||
int err = llama_control_vector_apply(lctx,
|
int err = llama_apply_adapter_cvec(
|
||||||
|
lctx,
|
||||||
cvec.data.data(),
|
cvec.data.data(),
|
||||||
cvec.data.size(),
|
cvec.data.size(),
|
||||||
cvec.n_embd,
|
cvec.n_embd,
|
||||||
@@ -913,7 +977,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
params.control_vector_layer_end);
|
params.control_vector_layer_end);
|
||||||
if (err) {
|
if (err) {
|
||||||
llama_free(lctx);
|
llama_free(lctx);
|
||||||
llama_free_model(model);
|
llama_model_free(model);
|
||||||
|
|
||||||
return iparams;
|
return iparams;
|
||||||
}
|
}
|
||||||
@@ -921,12 +985,12 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
|
|
||||||
// load and optionally apply lora adapters
|
// load and optionally apply lora adapters
|
||||||
for (auto & la : params.lora_adapters) {
|
for (auto & la : params.lora_adapters) {
|
||||||
llama_lora_adapter_ptr lora;
|
llama_adapter_lora_ptr lora;
|
||||||
lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
|
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||||
if (lora == nullptr) {
|
if (lora == nullptr) {
|
||||||
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
||||||
llama_free(lctx);
|
llama_free(lctx);
|
||||||
llama_free_model(model);
|
llama_model_free(model);
|
||||||
return iparams;
|
return iparams;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -935,17 +999,17 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!params.lora_init_without_apply) {
|
if (!params.lora_init_without_apply) {
|
||||||
common_lora_adapters_apply(lctx, params.lora_adapters);
|
common_set_adapter_lora(lctx, params.lora_adapters);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
|
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||||
LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||||
params.sampling.ignore_eos = false;
|
params.sampling.ignore_eos = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.sampling.ignore_eos) {
|
if (params.sampling.ignore_eos) {
|
||||||
for (llama_token i = 0; i < llama_n_vocab(model); i++) {
|
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||||
if (llama_token_is_eog(model, i)) {
|
if (llama_vocab_is_eog(vocab, i)) {
|
||||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||||
params.sampling.logit_bias.push_back({i, -INFINITY});
|
params.sampling.logit_bias.push_back({i, -INFINITY});
|
||||||
}
|
}
|
||||||
@@ -966,8 +1030,9 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||||
|
|
||||||
std::vector<llama_token> tmp;
|
std::vector<llama_token> tmp;
|
||||||
llama_token bos = llama_token_bos(model);
|
llama_token bos = llama_vocab_bos(vocab);
|
||||||
llama_token eos = llama_token_eos(model);
|
llama_token eos = llama_vocab_eos(vocab);
|
||||||
|
|
||||||
// some models (e.g. T5) don't have a BOS token
|
// some models (e.g. T5) don't have a BOS token
|
||||||
if (bos != LLAMA_TOKEN_NULL) {
|
if (bos != LLAMA_TOKEN_NULL) {
|
||||||
tmp.push_back(bos);
|
tmp.push_back(bos);
|
||||||
@@ -982,7 +1047,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
if (llama_model_has_encoder(model)) {
|
if (llama_model_has_encoder(model)) {
|
||||||
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
|
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||||
if (decoder_start_token_id == -1) {
|
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
|
||||||
decoder_start_token_id = bos;
|
decoder_start_token_id = bos;
|
||||||
}
|
}
|
||||||
tmp.clear();
|
tmp.clear();
|
||||||
@@ -1002,11 +1067,11 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
return iparams;
|
return iparams;
|
||||||
}
|
}
|
||||||
|
|
||||||
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
|
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
|
||||||
llama_lora_adapter_clear(ctx);
|
llama_clear_adapter_lora(ctx);
|
||||||
for (auto & la : lora) {
|
for (auto & la : lora) {
|
||||||
if (la.scale != 0.0f) {
|
if (la.scale != 0.0f) {
|
||||||
llama_lora_adapter_set(ctx, la.ptr, la.scale);
|
llama_set_adapter_lora(ctx, la.ptr, la.scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1020,7 +1085,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
|||||||
if (params.n_gpu_layers != -1) {
|
if (params.n_gpu_layers != -1) {
|
||||||
mparams.n_gpu_layers = params.n_gpu_layers;
|
mparams.n_gpu_layers = params.n_gpu_layers;
|
||||||
}
|
}
|
||||||
mparams.rpc_servers = params.rpc_servers.c_str();
|
|
||||||
mparams.main_gpu = params.main_gpu;
|
mparams.main_gpu = params.main_gpu;
|
||||||
mparams.split_mode = params.split_mode;
|
mparams.split_mode = params.split_mode;
|
||||||
mparams.tensor_split = params.tensor_split;
|
mparams.tensor_split = params.tensor_split;
|
||||||
@@ -1123,7 +1187,8 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
|
|||||||
|
|
||||||
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
|
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
|
||||||
// Initialize libcurl
|
// Initialize libcurl
|
||||||
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
|
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
|
||||||
|
curl_slist_ptr http_headers;
|
||||||
if (!curl) {
|
if (!curl) {
|
||||||
LOG_ERR("%s: error initializing libcurl\n", __func__);
|
LOG_ERR("%s: error initializing libcurl\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
@@ -1137,11 +1202,9 @@ static bool common_download_file(const std::string & url, const std::string & pa
|
|||||||
|
|
||||||
// Check if hf-token or bearer-token was specified
|
// Check if hf-token or bearer-token was specified
|
||||||
if (!hf_token.empty()) {
|
if (!hf_token.empty()) {
|
||||||
std::string auth_header = "Authorization: Bearer ";
|
std::string auth_header = "Authorization: Bearer " + hf_token;
|
||||||
auth_header += hf_token.c_str();
|
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
|
||||||
struct curl_slist *http_headers = NULL;
|
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
|
||||||
http_headers = curl_slist_append(http_headers, auth_header.c_str());
|
|
||||||
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
@@ -1411,7 +1474,7 @@ struct llama_model * common_load_model_from_url(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_load_model_from_file(local_path.c_str(), params);
|
return llama_model_load_from_file(local_path.c_str(), params);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_model * common_load_model_from_hf(
|
struct llama_model * common_load_model_from_hf(
|
||||||
@@ -1437,6 +1500,80 @@ struct llama_model * common_load_model_from_hf(
|
|||||||
return common_load_model_from_url(model_url, local_path, hf_token, params);
|
return common_load_model_from_url(model_url, local_path, hf_token, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
|
||||||
|
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
|
||||||
|
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
|
||||||
|
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
|
||||||
|
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
|
||||||
|
*
|
||||||
|
* Return pair of <repo, file> (with "repo" already having tag removed)
|
||||||
|
*
|
||||||
|
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
|
||||||
|
*/
|
||||||
|
std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
|
||||||
|
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
|
||||||
|
std::string tag = parts.size() > 1 ? parts.back() : "latest";
|
||||||
|
std::string hf_repo = parts[0];
|
||||||
|
if (string_split<std::string>(hf_repo, '/').size() != 2) {
|
||||||
|
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetch model info from Hugging Face Hub API
|
||||||
|
json model_info;
|
||||||
|
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
|
||||||
|
curl_slist_ptr http_headers;
|
||||||
|
std::string res_str;
|
||||||
|
std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
|
||||||
|
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
||||||
|
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
|
||||||
|
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
|
||||||
|
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
|
||||||
|
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
|
||||||
|
return size * nmemb;
|
||||||
|
};
|
||||||
|
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
|
||||||
|
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
|
||||||
|
#if defined(_WIN32)
|
||||||
|
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
|
||||||
|
#endif
|
||||||
|
if (!hf_token.empty()) {
|
||||||
|
std::string auth_header = "Authorization: Bearer " + hf_token;
|
||||||
|
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
|
||||||
|
}
|
||||||
|
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
|
||||||
|
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
|
||||||
|
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
|
||||||
|
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
|
||||||
|
|
||||||
|
CURLcode res = curl_easy_perform(curl.get());
|
||||||
|
|
||||||
|
if (res != CURLE_OK) {
|
||||||
|
throw std::runtime_error("error: cannot make GET request to HF API");
|
||||||
|
}
|
||||||
|
|
||||||
|
long res_code;
|
||||||
|
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
|
||||||
|
if (res_code == 200) {
|
||||||
|
model_info = json::parse(res_str);
|
||||||
|
} else if (res_code == 401) {
|
||||||
|
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// check response
|
||||||
|
if (!model_info.contains("ggufFile")) {
|
||||||
|
throw std::runtime_error("error: model does not have ggufFile");
|
||||||
|
}
|
||||||
|
json & gguf_file = model_info.at("ggufFile");
|
||||||
|
if (!gguf_file.contains("rfilename")) {
|
||||||
|
throw std::runtime_error("error: ggufFile does not have rfilename");
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_pair(hf_repo, gguf_file.at("rfilename"));
|
||||||
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
struct llama_model * common_load_model_from_url(
|
struct llama_model * common_load_model_from_url(
|
||||||
@@ -1458,6 +1595,11 @@ struct llama_model * common_load_model_from_hf(
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
|
||||||
|
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
|
||||||
|
return std::make_pair("", "");
|
||||||
|
}
|
||||||
|
|
||||||
#endif // LLAMA_USE_CURL
|
#endif // LLAMA_USE_CURL
|
||||||
|
|
||||||
//
|
//
|
||||||
@@ -1556,21 +1698,23 @@ std::vector<llama_token> common_tokenize(
|
|||||||
const std::string & text,
|
const std::string & text,
|
||||||
bool add_special,
|
bool add_special,
|
||||||
bool parse_special) {
|
bool parse_special) {
|
||||||
return common_tokenize(llama_get_model(ctx), text, add_special, parse_special);
|
const llama_model * model = llama_get_model(ctx);
|
||||||
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
return common_tokenize(vocab, text, add_special, parse_special);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> common_tokenize(
|
std::vector<llama_token> common_tokenize(
|
||||||
const struct llama_model * model,
|
const struct llama_vocab * vocab,
|
||||||
const std::string & text,
|
const std::string & text,
|
||||||
bool add_special,
|
bool add_special,
|
||||||
bool parse_special) {
|
bool parse_special) {
|
||||||
// upper limit for the number of tokens
|
// upper limit for the number of tokens
|
||||||
int n_tokens = text.length() + 2 * add_special;
|
int n_tokens = text.length() + 2 * add_special;
|
||||||
std::vector<llama_token> result(n_tokens);
|
std::vector<llama_token> result(n_tokens);
|
||||||
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
||||||
if (n_tokens < 0) {
|
if (n_tokens < 0) {
|
||||||
result.resize(-n_tokens);
|
result.resize(-n_tokens);
|
||||||
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
||||||
GGML_ASSERT(check == -n_tokens);
|
GGML_ASSERT(check == -n_tokens);
|
||||||
} else {
|
} else {
|
||||||
result.resize(n_tokens);
|
result.resize(n_tokens);
|
||||||
@@ -1579,12 +1723,18 @@ std::vector<llama_token> common_tokenize(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
|
std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
|
||||||
|
const llama_model * model = llama_get_model(ctx);
|
||||||
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
return common_token_to_piece(vocab, token, special);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
|
||||||
std::string piece;
|
std::string piece;
|
||||||
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
|
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
|
||||||
const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
|
const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
|
||||||
if (n_chars < 0) {
|
if (n_chars < 0) {
|
||||||
piece.resize(-n_chars);
|
piece.resize(-n_chars);
|
||||||
int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
|
int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
|
||||||
GGML_ASSERT(check == -n_chars);
|
GGML_ASSERT(check == -n_chars);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
@@ -1594,13 +1744,19 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
|
|||||||
return piece;
|
return piece;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
|
std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
|
||||||
|
const llama_model * model = llama_get_model(ctx);
|
||||||
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
return common_detokenize(vocab, tokens, special);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
|
||||||
std::string text;
|
std::string text;
|
||||||
text.resize(std::max(text.capacity(), tokens.size()));
|
text.resize(std::max(text.capacity(), tokens.size()));
|
||||||
int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
||||||
if (n_chars < 0) {
|
if (n_chars < 0) {
|
||||||
text.resize(-n_chars);
|
text.resize(-n_chars);
|
||||||
n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
||||||
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
|
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1610,103 +1766,6 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
|
|||||||
return text;
|
return text;
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// Chat template utils
|
|
||||||
//
|
|
||||||
|
|
||||||
std::string common_get_builtin_chat_template(const struct llama_model * model) {
|
|
||||||
static const char * template_key = "tokenizer.chat_template";
|
|
||||||
// call with NULL buffer to get the total size of the string
|
|
||||||
int32_t res = llama_model_meta_val_str(model, template_key, NULL, 0);
|
|
||||||
if (res > 0) {
|
|
||||||
std::vector<char> model_template(res + 1, 0);
|
|
||||||
llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size());
|
|
||||||
return std::string(model_template.data(), model_template.size() - 1);
|
|
||||||
}
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
bool common_chat_verify_template(const std::string & tmpl) {
|
|
||||||
llama_chat_message chat[] = {{"user", "test"}};
|
|
||||||
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
|
||||||
return res >= 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string common_chat_apply_template(const struct llama_model * model,
|
|
||||||
const std::string & tmpl,
|
|
||||||
const std::vector<common_chat_msg> & msgs,
|
|
||||||
bool add_ass) {
|
|
||||||
int alloc_size = 0;
|
|
||||||
bool fallback = false; // indicate if we must fallback to default chatml
|
|
||||||
std::vector<llama_chat_message> chat;
|
|
||||||
for (auto & msg : msgs) {
|
|
||||||
chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
|
||||||
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
|
|
||||||
std::vector<char> buf(alloc_size);
|
|
||||||
|
|
||||||
// run the first time to get the total output length
|
|
||||||
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
|
||||||
|
|
||||||
// error: chat template is not supported
|
|
||||||
if (res < 0) {
|
|
||||||
if (ptr_tmpl != nullptr) {
|
|
||||||
// if the custom "tmpl" is not supported, we throw an error
|
|
||||||
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
|
||||||
throw std::runtime_error("this custom template is not supported");
|
|
||||||
} else {
|
|
||||||
// If the built-in template is not supported, we default to chatml
|
|
||||||
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
|
||||||
fallback = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if it turns out that our buffer is too small, we resize it
|
|
||||||
if ((size_t) res > buf.size()) {
|
|
||||||
buf.resize(res);
|
|
||||||
res = llama_chat_apply_template(
|
|
||||||
fallback ? nullptr : model,
|
|
||||||
fallback ? "chatml" : ptr_tmpl,
|
|
||||||
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string formatted_chat(buf.data(), res);
|
|
||||||
return formatted_chat;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string common_chat_format_single(const struct llama_model * model,
|
|
||||||
const std::string & tmpl,
|
|
||||||
const std::vector<common_chat_msg> & past_msg,
|
|
||||||
const common_chat_msg & new_msg,
|
|
||||||
bool add_ass) {
|
|
||||||
std::ostringstream ss;
|
|
||||||
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
|
|
||||||
std::vector<common_chat_msg> chat_new(past_msg);
|
|
||||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
|
||||||
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
|
||||||
ss << "\n";
|
|
||||||
};
|
|
||||||
// format chat with new_msg
|
|
||||||
chat_new.push_back(new_msg);
|
|
||||||
auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
|
|
||||||
// get the diff part
|
|
||||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string common_chat_format_example(const struct llama_model * model,
|
|
||||||
const std::string & tmpl) {
|
|
||||||
std::vector<common_chat_msg> msgs = {
|
|
||||||
{"system", "You are a helpful assistant"},
|
|
||||||
{"user", "Hello"},
|
|
||||||
{"assistant", "Hi there"},
|
|
||||||
{"user", "How are you?"},
|
|
||||||
};
|
|
||||||
return common_chat_apply_template(model, tmpl, msgs, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// KV cache utils
|
// KV cache utils
|
||||||
//
|
//
|
||||||
|
|||||||
114
llama/llama.cpp/common/common.h
vendored
114
llama/llama.cpp/common/common.h
vendored
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "llama-cpp.h"
|
#include "llama-cpp.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
@@ -24,11 +25,11 @@
|
|||||||
|
|
||||||
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
||||||
|
|
||||||
struct common_lora_adapter_info {
|
struct common_adapter_lora_info {
|
||||||
std::string path;
|
std::string path;
|
||||||
float scale;
|
float scale;
|
||||||
|
|
||||||
struct llama_lora_adapter * ptr;
|
struct llama_adapter_lora * ptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
using llama_tokens = std::vector<llama_token>;
|
using llama_tokens = std::vector<llama_token>;
|
||||||
@@ -103,6 +104,17 @@ enum dimre_method {
|
|||||||
DIMRE_METHOD_MEAN,
|
DIMRE_METHOD_MEAN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum common_conversation_mode {
|
||||||
|
COMMON_CONVERSATION_MODE_DISABLED = 0,
|
||||||
|
COMMON_CONVERSATION_MODE_ENABLED = 1,
|
||||||
|
COMMON_CONVERSATION_MODE_AUTO = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_grammar_trigger {
|
||||||
|
std::string word;
|
||||||
|
bool at_start;
|
||||||
|
};
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
struct common_params_sampling {
|
struct common_params_sampling {
|
||||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||||
@@ -128,6 +140,7 @@ struct common_params_sampling {
|
|||||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
|
float top_n_sigma = -1.00f;// -1.0 = disabled
|
||||||
float mirostat_tau = 5.00f; // target entropy
|
float mirostat_tau = 5.00f; // target entropy
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
bool ignore_eos = false;
|
bool ignore_eos = false;
|
||||||
@@ -149,6 +162,10 @@ struct common_params_sampling {
|
|||||||
};
|
};
|
||||||
|
|
||||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||||
|
bool grammar_lazy = false;
|
||||||
|
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
|
||||||
|
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
|
||||||
|
std::set<llama_token> preserved_tokens;
|
||||||
|
|
||||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||||
|
|
||||||
@@ -161,15 +178,19 @@ struct common_params_speculative {
|
|||||||
|
|
||||||
int32_t n_ctx = 0; // draft context size
|
int32_t n_ctx = 0; // draft context size
|
||||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||||
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
|
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||||
float p_split = 0.1f; // speculative decoding split probability
|
float p_split = 0.1f; // speculative decoding split probability
|
||||||
float p_min = 0.9f; // minimum speculative decoding probability (greedy)
|
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||||
|
|
||||||
struct cpu_params cpuparams;
|
struct cpu_params cpuparams;
|
||||||
struct cpu_params cpuparams_batch;
|
struct cpu_params cpuparams_batch;
|
||||||
|
|
||||||
|
std::string hf_repo = ""; // HF repo // NOLINT
|
||||||
|
std::string hf_file = ""; // HF file // NOLINT
|
||||||
|
|
||||||
std::string model = ""; // draft model for speculative decoding // NOLINT
|
std::string model = ""; // draft model for speculative decoding // NOLINT
|
||||||
|
std::string model_url = ""; // model url to download // NOLINT
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_params_vocoder {
|
struct common_params_vocoder {
|
||||||
@@ -178,6 +199,13 @@ struct common_params_vocoder {
|
|||||||
|
|
||||||
std::string model = ""; // model path // NOLINT
|
std::string model = ""; // model path // NOLINT
|
||||||
std::string model_url = ""; // model url to download // NOLINT
|
std::string model_url = ""; // model url to download // NOLINT
|
||||||
|
|
||||||
|
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||||
|
};
|
||||||
|
|
||||||
|
enum common_reasoning_format {
|
||||||
|
COMMON_REASONING_FORMAT_NONE,
|
||||||
|
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_params {
|
struct common_params {
|
||||||
@@ -240,14 +268,13 @@ struct common_params {
|
|||||||
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
||||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||||
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
||||||
std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT
|
|
||||||
|
|
||||||
std::vector<std::string> in_files; // all input files
|
std::vector<std::string> in_files; // all input files
|
||||||
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
|
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
|
||||||
std::vector<llama_model_kv_override> kv_overrides;
|
std::vector<llama_model_kv_override> kv_overrides;
|
||||||
|
|
||||||
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
|
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
|
||||||
std::vector<common_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
|
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
|
||||||
|
|
||||||
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
|
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
|
||||||
|
|
||||||
@@ -271,11 +298,11 @@ struct common_params {
     bool kl_divergence = false; // compute KL divergence
 
     bool usage = false; // print usage
+    bool completion = false; // print source-able completion script
     bool use_color = false; // use color to distinguish generations and inputs
     bool special = false; // enable special token output
     bool interactive = false; // interactive mode
     bool interactive_first = false; // wait for user input immediately
-    bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
     bool prompt_cache_all = false; // save user input and generations to prompt cache
     bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
 
@@ -301,6 +328,8 @@ struct common_params {
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
 
+    common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
+
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector // NOLINT
     std::vector<std::string> image; // path to image file(s)
@@ -322,7 +351,9 @@ struct common_params {
     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
     std::string chat_template = ""; // NOLINT
+    bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
 
     std::vector<std::string> api_keys;
 
@@ -401,7 +432,7 @@ bool set_process_priority(enum ggml_sched_priority prio);
 //
 
 #ifdef __GNUC__
-#ifdef __MINGW32__
+# if defined(__MINGW32__) && !defined(__clang__)
 # define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 # else
 # define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
@@ -416,6 +447,10 @@ std::string string_format(const char * fmt, ...);
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
 
+std::string string_join(const std::vector<std::string> & values, const std::string & separator);
+std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
+std::string string_repeat(const std::string & str, size_t n);
+
 void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
 
 template<class T>
@@ -454,6 +489,11 @@ static bool string_starts_with(const std::string & str,
     return str.rfind(prefix, 0) == 0;
 }
 
+static bool string_ends_with(const std::string & str,
+                             const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
+    return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
+}
+
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
 
@@ -481,7 +521,7 @@ struct common_init_result {
     llama_model_ptr model;
     llama_context_ptr context;
 
-    std::vector<llama_lora_adapter_ptr> lora;
+    std::vector<llama_adapter_lora_ptr> lora;
 };
 
 struct common_init_result common_init_from_params(common_params & params);
@@ -495,6 +535,7 @@ struct llama_model * common_load_model_from_url(
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);
+
 struct llama_model * common_load_model_from_hf(
     const std::string & repo,
     const std::string & remote_path,
@@ -502,8 +543,12 @@ struct llama_model * common_load_model_from_hf(
     const std::string & hf_token,
     const struct llama_model_params & params);
 
+std::pair<std::string, std::string> common_get_hf_file(
+    const std::string & hf_repo_with_tag,
+    const std::string & hf_token);
+
 // clear LoRA adapters from context, then apply new list of adapters
-void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);
+void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
 //
 // Batch utils
@@ -541,7 +586,7 @@ std::vector<llama_token> common_tokenize(
     bool parse_special = false);
 
 std::vector<llama_token> common_tokenize(
-    const struct llama_model * model,
+    const struct llama_vocab * vocab,
     const std::string & text,
     bool add_special,
     bool parse_special = false);
@@ -553,48 +598,23 @@ std::string common_token_to_piece(
     llama_token token,
     bool special = true);
 
+std::string common_token_to_piece(
+    const struct llama_vocab * vocab,
+    llama_token token,
+    bool special = true);
 
 // detokenizes a vector of tokens into a string
 // should work similar to Python's `tokenizer.decode`
 // optionally renders special/control tokens
 std::string common_detokenize(
-    llama_context * ctx,
+    const struct llama_context * ctx,
     const std::vector<llama_token> & tokens,
     bool special = true);
 
-//
-// Chat template utils
-//
-// same with llama_chat_message, but uses std::string
-struct common_chat_msg {
-    std::string role;
-    std::string content;
-};
-
-// Get the built-in chat template for the model. Return empty string if not present.
-std::string common_get_builtin_chat_template(const struct llama_model * model);
-
-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl);
-
-// CPP wrapper for llama_chat_apply_template
-// If the built-in template is not supported, we default to chatml
-// If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(const struct llama_model * model,
-    const std::string & tmpl,
-    const std::vector<common_chat_msg> & chat,
-    bool add_ass);
-
-// Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(const struct llama_model * model,
-    const std::string & tmpl,
-    const std::vector<common_chat_msg> & past_msg,
-    const common_chat_msg & new_msg,
-    bool add_ass);
-
-// Returns an example of formatted chat
-std::string common_chat_format_example(const struct llama_model * model,
-    const std::string & tmpl);
+std::string common_detokenize(
+    const struct llama_vocab * vocab,
+    const std::vector<llama_token> & tokens,
+    bool special = true);
 
 //
 // KV cache utils
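Not part of the diff: a minimal sketch of how a caller migrates to the llama_vocab-based helpers declared above. It assumes a valid llama_context (for example one obtained from common_init_from_params); the helper name roundtrip is hypothetical, but every call it makes (llama_get_model, llama_model_get_vocab, common_tokenize, common_detokenize) appears in these hunks or in the sampling.cpp changes further below.

#include "common.h"
#include "llama.h"

// Sketch only: tokenize and detokenize through the vocab instead of the model.
static std::string roundtrip(llama_context * ctx, const std::string & text) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // add_special = true lets the vocab add BOS/EOS as configured;
    // parse_special = false leaves special-token text untokenized.
    std::vector<llama_token> toks = common_tokenize(vocab, text, /*add_special=*/true, /*parse_special=*/false);

    // render the tokens back, including special/control tokens
    return common_detokenize(vocab, toks, /*special=*/true);
}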
llama/llama.cpp/common/json-schema-to-grammar.cpp (vendored): 110 changes
@@ -1,4 +1,6 @@
 #include "json-schema-to-grammar.h"
+#include "common.h"
+
 #include <algorithm>
 #include <fstream>
 #include <map>
@@ -11,11 +13,6 @@
 
 using json = nlohmann::ordered_json;
 
-template <typename Iterator>
-static std::string join(Iterator begin, Iterator end, const std::string & separator);
-
-static std::string repeat(const std::string & str, size_t n);
-
 static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
     auto has_max = max_items != std::numeric_limits<int>::max();
 
@@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         if (sub_len > 0) {
             auto from_sub = from.substr(i + 1);
             auto to_sub = to.substr(i + 1);
-            auto sub_zeros = repeat("0", sub_len);
-            auto sub_nines = repeat("9", sub_len);
+            auto sub_zeros = string_repeat("0", sub_len);
+            auto sub_nines = string_repeat("9", sub_len);
 
             auto to_reached = false;
             out << "(";
@@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
         auto max_digits = max_s.length();
 
         for (auto digits = min_digits; digits < max_digits; digits++) {
-            uniform_range(min_s, repeat("9", digits));
-            min_s = "1" + repeat("0", digits);
+            uniform_range(min_s, string_repeat("9", digits));
+            min_s = "1" + string_repeat("0", digits);
             out << " | ";
         }
         uniform_range(min_s, max_s);
@@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
 std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
 std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
 
-template <typename Iterator>
-std::string join(Iterator begin, Iterator end, const std::string & separator) {
-    std::ostringstream result;
-    if (begin != end) {
-        result << *begin;
-        for (Iterator it = begin + 1; it != end; ++it) {
-            result << separator << *it;
-        }
-    }
-    return result.str();
-}
-
-static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
-    std::vector<std::string> tokens;
-    size_t start = 0;
-    size_t end = str.find(delimiter);
-
-    while (end != std::string::npos) {
-        tokens.push_back(str.substr(start, end - start));
-        start = end + delimiter.length();
-        end = str.find(delimiter, start);
-    }
-
-    tokens.push_back(str.substr(start));
-
-    return tokens;
-}
-
-static std::string repeat(const std::string & str, size_t n) {
-    if (n == 0) {
-        return "";
-    }
-
-    std::string result;
-    result.reserve(str.length() * n);
-
-    for (size_t i = 0; i < n; ++i) {
-        result += str;
-    }
-
-    return result;
-}
-
 static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
     std::smatch match;
     std::string result;
@@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) {
 
 class SchemaConverter {
 private:
+    friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
     std::function<json(const std::string &)> _fetch_json;
     bool _dotall;
     std::unordered_map<std::string, std::string> _rules;
@@ -418,7 +373,7 @@ private:
         for (size_t i = 0; i < alt_schemas.size(); i++) {
             rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
         }
-        return join(rules.begin(), rules.end(), " | ");
+        return string_join(rules, " | ");
     }
 
     std::string _visit_pattern(const std::string & pattern, const std::string & name) {
@@ -481,7 +436,7 @@ private:
             for (const auto & item : ret) {
                 results.push_back(to_rule(item));
             }
-            return std::make_pair(join(results.begin(), results.end(), " "), false);
+            return std::make_pair(string_join(results, " "), false);
         };
 
         while (i < length) {
@@ -539,7 +494,7 @@ private:
                 }
                 curly_brackets += '}';
                 i++;
-                auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
+                auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
                 int min_times = 0;
                 int max_times = std::numeric_limits<int>::max();
                 try {
@@ -809,10 +764,11 @@ private:
 public:
     SchemaConverter(
         const std::function<json(const std::string &)> & fetch_json,
-        bool dotall)
+        bool dotall,
+        bool compact_spaces)
         : _fetch_json(fetch_json), _dotall(dotall)
     {
-        _rules["space"] = SPACE_RULE;
+        _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
     }
 
     void resolve_refs(json & schema, const std::string & url) {
@@ -854,7 +810,7 @@ public:
             return;
         }
         std::string pointer = ref.substr(ref.find('#') + 1);
-        std::vector<std::string> tokens = split(pointer, "/");
+        std::vector<std::string> tokens = string_split(pointer, "/");
         for (size_t i = 1; i < tokens.size(); ++i) {
             std::string sel = tokens[i];
             if (target.is_null() || !target.contains(sel)) {
@@ -905,7 +861,7 @@ public:
             for (const auto & v : schema["enum"]) {
                 enum_values.push_back(_generate_constant_rule(v));
             }
-            return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
+            return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
         } else if ((schema_type.is_null() || schema_type == "object")
                 && (schema.contains("properties") ||
                     (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@@ -1019,10 +975,10 @@ public:
 
     void check_errors() {
         if (!_errors.empty()) {
-            throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
+            throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
        }
         if (!_warnings.empty()) {
-            fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
+            fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
         }
     }
 
@@ -1035,11 +991,35 @@ public:
     }
 };
 
-std::string json_schema_to_grammar(const json & schema) {
-    SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
+std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
+#ifdef LLAMA_USE_LLGUIDANCE
+    if (!force_gbnf) {
+        return "%llguidance {}\nstart: %json " + schema.dump();
+    }
+#else
+    (void)force_gbnf;
+#endif // LLAMA_USE_LLGUIDANCE
+    return build_grammar([&](const common_grammar_builder & callbacks) {
         auto copy = schema;
-    converter.resolve_refs(copy, "input");
-    converter.visit(copy, "");
+        callbacks.resolve_refs(copy);
+        callbacks.add_schema("", copy);
+    });
+}
+
+std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
+    SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
+    common_grammar_builder builder {
+        /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
+            return converter._add_rule(name, rule);
+        },
+        /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
+            return converter.visit(schema, name == "root" ? "" : name);
+        },
+        /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
+            converter.resolve_refs(schema, "");
+        }
+    };
+    cb(builder);
     converter.check_errors();
     return converter.format_grammar();
 }
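Not part of the diff: a small sketch of the common.h string helpers that replace the file-local join/split/repeat statics removed above. Behavior is inferred from the removed implementations and the call sites in this file; exact edge cases are an assumption.

#include "common.h"
#include <cstdio>

int main() {
    // string_split splits on a delimiter string, string_join concatenates with a
    // separator, string_repeat repeats a string n times.
    std::vector<std::string> parts = string_split("definitions/answer/type", "/"); // {"definitions", "answer", "type"}
    std::string joined = string_join(parts, " | ");                                // "definitions | answer | type"
    std::string zeros  = string_repeat("0", 3);                                    // "000"
    printf("%s %s\n", joined.c_str(), zeros.c_str());
    return 0;
}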
llama/llama.cpp/common/json-schema-to-grammar.h (vendored): 16 changes
@@ -5,4 +5,18 @@
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 
-std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
+std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
+    bool force_gbnf = false);
+
+struct common_grammar_builder {
+    std::function<std::string(const std::string &, const std::string &)> add_rule;
+    std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
+    std::function<void(nlohmann::ordered_json &)> resolve_refs;
+};
+
+struct common_grammar_options {
+    bool dotall = false;
+    bool compact_spaces = false;
+};
+
+std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
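Not part of the diff: a sketch of how the new callback-based build_grammar entry point declared above can be driven directly, mirroring what json_schema_to_grammar itself does in the .cpp hunk. The schema literal and the example_grammar name are illustrative only.

#include "json-schema-to-grammar.h"

// Sketch: build a GBNF grammar for a tiny object schema via the builder callbacks.
static std::string example_grammar() {
    nlohmann::ordered_json schema = {
        {"type", "object"},
        {"properties", {{"answer", {{"type", "string"}}}}},
        {"required", {"answer"}},
    };

    common_grammar_options opts;
    opts.compact_spaces = true; // use the compact "space" rule added in this change

    return build_grammar([&](const common_grammar_builder & builder) {
        builder.resolve_refs(schema);
        builder.add_schema("", schema); // "" (or "root") becomes the root rule
    }, opts);
}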
llama/llama.cpp/common/log.cpp (vendored): 12 changes
@@ -1,5 +1,6 @@
 #include "log.h"
 
+#include <chrono>
 #include <condition_variable>
 #include <cstdarg>
 #include <cstdio>
@@ -14,16 +15,6 @@ void common_log_set_verbosity_thold(int verbosity) {
     common_log_verbosity_thold = verbosity;
 }
 
-#define LOG_COL_DEFAULT "\033[0m"
-#define LOG_COL_BOLD "\033[1m"
-#define LOG_COL_RED "\033[31m"
-#define LOG_COL_GREEN "\033[32m"
-#define LOG_COL_YELLOW "\033[33m"
-#define LOG_COL_BLUE "\033[34m"
-#define LOG_COL_MAGENTA "\033[35m"
-#define LOG_COL_CYAN "\033[36m"
-#define LOG_COL_WHITE "\033[37m"
-
 static int64_t t_us() {
     return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
 }
@@ -206,6 +197,7 @@ public:
             vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
         }
 #endif
+        va_end(args_copy);
     }
 
     entry.level = level;
llama/llama.cpp/common/log.h (vendored): 13 changes
@@ -2,9 +2,20 @@
 
 #include "ggml.h" // for ggml_log_level
 
+#define LOG_CLR_TO_EOL "\033[K\r"
+#define LOG_COL_DEFAULT "\033[0m"
+#define LOG_COL_BOLD "\033[1m"
+#define LOG_COL_RED "\033[31m"
+#define LOG_COL_GREEN "\033[32m"
+#define LOG_COL_YELLOW "\033[33m"
+#define LOG_COL_BLUE "\033[34m"
+#define LOG_COL_MAGENTA "\033[35m"
+#define LOG_COL_CYAN "\033[36m"
+#define LOG_COL_WHITE "\033[37m"
+
 #ifndef __GNUC__
 # define LOG_ATTRIBUTE_FORMAT(...)
-#elif defined(__MINGW32__)
+#elif defined(__MINGW32__) && !defined(__clang__)
 # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 #else
 # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
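Not part of the diff: with the color macros moved into log.h they are visible to any file that includes it. A one-line sketch of what they expand to in practice; the macros are plain ANSI escape strings, so they concatenate with adjacent string literals.

#include "log.h"
#include <cstdio>

int main() {
    // prints "ok" in green and "fail" in red, then resets the terminal color
    printf(LOG_COL_GREEN "ok" LOG_COL_DEFAULT ", " LOG_COL_RED "fail" LOG_COL_DEFAULT "\n");
    return 0;
}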
llama/llama.cpp/common/sampling.cpp (vendored): 48 changes
@@ -113,7 +113,10 @@ struct common_sampler {
     void set_logits(struct llama_context * ctx, int idx) {
         const auto * logits = llama_get_logits_ith(ctx, idx);
 
-        const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+        const llama_model * model = llama_get_model(ctx);
+        const llama_vocab * vocab = llama_model_get_vocab(model);
+
+        const int n_vocab = llama_vocab_n_tokens(vocab);
 
         cur.resize(n_vocab);
 
@@ -131,24 +134,47 @@ std::string common_params_sampling::print() const {
     snprintf(result, sizeof(result),
         "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
         "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
-        "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
+        "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
         "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
         penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
         dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
-        top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
+        top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
         mirostat, mirostat_eta, mirostat_tau);
 
     return std::string(result);
 }
 
 struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
 
     lparams.no_perf = params.no_perf;
 
+    struct llama_sampler * grmr;
+    if (params.grammar.compare(0, 11, "%llguidance") == 0) {
+#ifdef LLAMA_USE_LLGUIDANCE
+        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
+#else
+        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
+#endif // LLAMA_USE_LLGUIDANCE
+    } else {
+        std::vector<const char *> trigger_words;
+        trigger_words.reserve(params.grammar_trigger_words.size());
+        for (const auto & str : params.grammar_trigger_words) {
+            trigger_words.push_back(str.word.c_str());
+        }
+
+        grmr = params.grammar_lazy
+            ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
+                trigger_words.data(), trigger_words.size(),
+                params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
+            : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
+    }
+
     auto * result = new common_sampler {
         /* .params = */ params,
-        /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .grmr = */ grmr,
         /* .chain = */ llama_sampler_chain_init(lparams),
         /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur = */ {},
@@ -157,11 +183,16 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     llama_sampler_chain_add(result->chain,
         llama_sampler_init_logit_bias(
-            llama_n_vocab(model),
+            llama_vocab_n_tokens(vocab),
             params.logit_bias.size(),
             params.logit_bias.data()));
 
     if (params.mirostat == 0) {
+        if (params.top_n_sigma >= 0) {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+        } else {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
@@ -172,7 +203,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                         c_breakers.push_back(str.c_str());
                     }
 
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                 }
                 break;
             case COMMON_SAMPLER_TYPE_TOP_K:
@@ -194,7 +225,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
             case COMMON_SAMPLER_TYPE_INFILL:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
                 break;
             case COMMON_SAMPLER_TYPE_PENALTIES:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
@@ -203,10 +234,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     GGML_ASSERT(false && "unknown sampler type");
             }
         }
+        }
         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
     } else if (params.mirostat == 2) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
         llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
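Not part of the diff: a sketch of the chain the new top_n_sigma branch above assembles, written against the llama.h sampler calls the hunk already uses. The helper name make_top_n_sigma_chain and the parameter values are illustrative.

#include "llama.h"

// Sketch: when top_n_sigma >= 0 the chain is top_k -> temp -> top_n_sigma -> dist,
// bypassing the configurable sampler list.
static llama_sampler * make_top_n_sigma_chain(int32_t top_k, float temp, float top_n_sigma, uint32_t seed) {
    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
    llama_sampler * chain = llama_sampler_chain_init(lparams);

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(temp));
    llama_sampler_chain_add(chain, llama_sampler_init_top_n_sigma(top_n_sigma));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed));

    return chain;
}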
llama/llama.cpp/common/sampling.h (vendored): 3 changes
@@ -102,3 +102,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
 
 std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
+
+llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
+    const char * grammar_kind, const char * grammar_data);
Some files were not shown because too many files have changed in this diff.