Compare commits
413 Commits
pdevine/co…parth/pyth
| Author | SHA1 | Date |
|---|---|---|
|  | 23e8ac9428 |  |
| … | … | … |
|  | 7bb356c680 |  |
@@ -3,7 +3,9 @@ ollama
 app
 macapp
 dist
+build
 .env
 .cache
 test_data
-llama/build
+.git
+
.gitattributes (vendored): 13 changed lines

@@ -7,5 +7,18 @@ llama/**/*.cuh linguist-vendored
 llama/**/*.m linguist-vendored
 llama/**/*.metal linguist-vendored
 
+ml/backend/**/*.c linguist-vendored
+ml/backend/**/*.h linguist-vendored
+ml/backend/**/*.cpp linguist-vendored
+ml/backend/**/*.hpp linguist-vendored
+ml/backend/**/*.cu linguist-vendored
+ml/backend/**/*.cuh linguist-vendored
+ml/backend/**/*.m linguist-vendored
+ml/backend/**/*.metal linguist-vendored
+ml/backend/**/CMakeLists.txt linguist-vendored
+
+llama/build-info.cpp linguist-generated
+ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated
+
 * text=auto
 *.go text eol=lf
.github/ISSUE_TEMPLATE/10_bug_report.yml (vendored): 8 changed lines

@@ -9,6 +9,14 @@ body:
       description: What happened? What did you expect to happen?
     validations:
       required: true
+  - type: textarea
+    id: logs
+    attributes:
+      label: Relevant log output
+      description: Please copy and paste any relevant log output. See [Troubleshooting Guide](https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md#how-to-troubleshoot-issues) for details.
+      render: shell
+    validations:
+      required: false
   - type: dropdown
     id: os
     attributes:
.github/workflows/release.yaml (vendored): 1027 changed lines (file diff suppressed because it is too large)
.github/workflows/test.yaml (vendored): 454 changed lines

@@ -1,11 +1,5 @@
 name: test
 
-env:
-  ROCM_WINDOWS_URL: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe
-  MSYS2_URL: https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-x86_64-20240727.exe
-  CUDA_12_WINDOWS_URL: https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_551.61_windows.exe
-  CUDA_12_WINDOWS_VER: 12.4
-
 concurrency:
   # For PRs, later CI runs preempt previous ones. e.g. a force push on a PR
   # cancels running CI jobs and starts all new ones.
@@ -27,7 +21,7 @@ jobs:
   changes:
     runs-on: ubuntu-latest
     outputs:
-      RUNNERS: ${{ steps.changes.outputs.RUNNERS }}
+      changed: ${{ steps.changes.outputs.changed }}
     steps:
       - uses: actions/checkout@v4
         with:
@@ -35,309 +29,213 @@ jobs:
       - id: changes
         run: |
           changed() {
-            git diff-tree -r --no-commit-id --name-only \
-              $(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
-              ${{ github.event.pull_request.head.sha }} \
+            local BASE=${{ github.event.pull_request.base.sha }}
+            local HEAD=${{ github.event.pull_request.head.sha }}
+            local MERGE_BASE=$(git merge-base $BASE $HEAD)
+            git diff-tree -r --no-commit-id --name-only "$MERGE_BASE" "$HEAD" \
               | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
           }
 
-          {
-            echo RUNNERS=$(changed 'llama/**')
-          } >>$GITHUB_OUTPUT
+          echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
 
-  runners-linux-cuda:
+  linux:
     needs: [changes]
-    if: ${{ needs.changes.outputs.RUNNERS == 'True' }}
+    if: needs.changes.outputs.changed == 'True'
     strategy:
       matrix:
-        cuda-version:
-          - '11.8.0'
+        include:
+          - preset: CPU
+          - preset: CUDA
+            container: nvidia/cuda:11.8.0-devel-ubuntu22.04
+            flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
+          - preset: ROCm
+            container: rocm/dev-ubuntu-22.04:6.1.2
+            extra-packages: rocm-libs
+            flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
     runs-on: linux
-    container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04
+    container: ${{ matrix.container }}
     steps:
+      - uses: actions/checkout@v4
       - run: |
-          apt-get update && apt-get install -y git build-essential curl
+          [ -n "${{ matrix.container }}" ] || sudo=sudo
+          $sudo apt-get update
+          $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
         env:
           DEBIAN_FRONTEND: noninteractive
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v4
+      - uses: actions/cache@v4
         with:
-          go-version-file: go.mod
-          cache: true
-      - run: go get ./...
+          path: /github/home/.cache/ccache
+          key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
       - run: |
-          git config --global --add safe.directory /__w/ollama/ollama
-          cores=$(grep '^core id' /proc/cpuinfo |sort -u|wc -l)
-          make -j $cores cuda_v11
+          cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
+          cmake --build --preset ${{ matrix.preset }} --parallel
 
-  runners-linux-rocm:
+  windows:
     needs: [changes]
-    if: ${{ needs.changes.outputs.RUNNERS == 'True' }}
+    if: needs.changes.outputs.changed == 'True'
     strategy:
       matrix:
-        rocm-version:
-          - '6.1.2'
-    runs-on: linux
-    container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
-    steps:
-      - run: |
-          apt-get update && apt-get install -y git build-essential curl rocm-libs
-        env:
-          DEBIAN_FRONTEND: noninteractive
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v4
-        with:
-          go-version-file: go.mod
-          cache: true
-      - run: go get ./...
-      - run: |
-          git config --global --add safe.directory /__w/ollama/ollama
-          cores=$(grep '^core id' /proc/cpuinfo |sort -u|wc -l)
-          make -j $cores rocm
-
-  # ROCm generation step
-  runners-windows-rocm:
-    needs: [changes]
-    if: ${{ needs.changes.outputs.RUNNERS == 'True' }}
+        include:
+          - preset: CPU
+          - preset: CUDA
+            install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
+            flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
+          - preset: ROCm
+            install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
+            flags: '-DAMDGPU_TARGETS=gfx1010'
     runs-on: windows
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - run: |
+          choco install -y --no-progress ccache ninja
+          ccache -o cache_dir=${{ github.workspace }}\.ccache
+      - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm'
+        id: cache-install
+        uses: actions/cache/restore@v4
         with:
-          go-version-file: go.mod
-          cache: true
-      - name: Set make jobs default
-        run: |
-          echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
-
-      # ROCM installation steps
-      - name: 'Cache ROCm installer'
-        id: cache-rocm
-        uses: actions/cache@v4
-        with:
-          path: rocm-install.exe
-          key: ${{ env.ROCM_WINDOWS_URL }}
-      - name: 'Conditionally Download ROCm'
-        if: steps.cache-rocm.outputs.cache-hit != 'true'
+          path: |
+            C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
+            C:\Program Files\AMD\ROCm
+          key: ${{ matrix.install }}
+      - if: matrix.preset == 'CUDA'
+        name: Install CUDA ${{ matrix.cuda-version }}
         run: |
           $ErrorActionPreference = "Stop"
-          Invoke-WebRequest -Uri "${env:ROCM_WINDOWS_URL}" -OutFile "rocm-install.exe"
-      - name: 'Install ROCm'
-        run: |
-          Start-Process "rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
-      - name: 'Verify ROCm'
-        run: |
-          & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
-          echo "HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path | select -first 1)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
+          if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
+            Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
+            Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
+          }
 
-      - name: Add msys paths
-        run: |
-          echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-          echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-      - name: Install msys2 tools
-        run: |
-          Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
-
-      - name: make rocm runner
-        run: |
-          import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
-          Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo'
-          if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" }
-          make -C llama print-HIP_PATH print-HIP_LIB_DIR
-          make rocm
-
-      # CUDA generation step
-  runners-windows-cuda:
-    needs: [changes]
-    if: ${{ needs.changes.outputs.RUNNERS == 'True' }}
-    runs-on: windows
-    steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
-        with:
-          go-version-file: go.mod
-          cache: true
-      - name: Set make jobs default
-        run: |
-          echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
-
-      # CUDA installation steps
-      - name: 'Cache CUDA installer'
-        id: cache-cuda
-        uses: actions/cache@v4
-        with:
-          path: cuda-install.exe
-          key: ${{ env.CUDA_12_WINDOWS_URL }}
-      - name: 'Conditionally Download CUDA'
-        if: steps.cache-cuda.outputs.cache-hit != 'true'
-        run: |
-          $ErrorActionPreference = "Stop"
-          Invoke-WebRequest -Uri "${env:CUDA_12_WINDOWS_URL}" -OutFile "cuda-install.exe"
-      - name: 'Install CUDA'
-        run: |
-          $subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | foreach-object {"${_}_${{ env.CUDA_12_WINDOWS_VER }}"}
-          Start-Process "cuda-install.exe" -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
-      - name: 'Verify CUDA'
-        run: |
-          & (resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0] --version
-          $cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path)
-          $cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2'
+          $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
           echo "$cudaPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-          echo "CUDA_PATH=$cudaPath" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
-          echo "CUDA_PATH_V${cudaVer}=$cudaPath" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
-          echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
+      - if: matrix.preset == 'ROCm'
+        name: Install ROCm ${{ matrix.rocm-version }}
+        run: |
+          $ErrorActionPreference = "Stop"
+          if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
+            Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
+            Start-Process -FilePath .\install.exe -ArgumentList '-install' -NoNewWindow -Wait
+          }
 
-      - name: Add msys paths
-        run: |
-          echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-          echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-      - name: Install msys2 tools
-        run: |
-          Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
-      - name: make cuda runner
-        run: |
-          import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
-          Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo'
-          if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" }
-          make cuda_v$(($env:CUDA_PATH | split-path -leaf) -replace 'v(\d+).*', '$1')
+          $hipPath = (Resolve-Path "C:\Program Files\AMD\ROCm\*").path
+          echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
+          echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
+          echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
+      - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
+        uses: actions/cache/save@v4
+        with:
+          path: |
+            C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
+            C:\Program Files\AMD\ROCm
+          key: ${{ matrix.install }}
 
-  runners-cpu:
-    needs: [changes]
-    if: ${{ needs.changes.outputs.RUNNERS == 'True' }}
-    strategy:
-      matrix:
-        os: [ubuntu-latest, macos-latest, windows-2019]
-        arch: [amd64, arm64]
-        exclude:
-          - os: ubuntu-latest
-            arch: arm64
-          - os: windows-2019
-            arch: arm64
-    runs-on: ${{ matrix.os }}
-    env:
-      GOARCH: ${{ matrix.arch }}
-      ARCH: ${{ matrix.arch }}
-      CGO_ENABLED: '1'
-    steps:
       - uses: actions/checkout@v4
-      - uses: actions/setup-go@v5
+      - uses: actions/cache@v4
         with:
-          go-version-file: go.mod
-          cache: true
-      - name: Add msys paths
-        if: ${{ startsWith(matrix.os, 'windows-') }}
-        run: |
-          echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-          echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-      - name: Install msys2 tools
-        if: ${{ startsWith(matrix.os, 'windows-') }}
-        run: |
-          Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
-      - name: 'Build Windows Go Runners'
-        if: ${{ startsWith(matrix.os, 'windows-') }}
-        run: |
-          $gopath=(get-command go).source | split-path -parent
-          $gccpath=(get-command gcc).source | split-path -parent
-          import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
-          Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo'
-          $env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
-          $env:PATH="$gopath;$gccpath;$env:PATH"
-          echo $env:PATH
-          if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" }
-          make -j 4
-      - name: 'Build Unix Go Runners'
-        if: ${{ ! startsWith(matrix.os, 'windows-') }}
-        run: make -j 4
-      - run: go build .
-
-  lint:
-    strategy:
-      matrix:
-        os: [ubuntu-latest, macos-latest, windows-2019]
-        arch: [amd64, arm64]
-        exclude:
-          - os: ubuntu-latest
-            arch: arm64
-          - os: windows-2019
-            arch: arm64
-          - os: macos-latest
-            arch: amd64
-    runs-on: ${{ matrix.os }}
-    env:
-      GOARCH: ${{ matrix.arch }}
-      CGO_ENABLED: '1'
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          submodules: recursive
-      - name: Add msys paths
-        if: ${{ startsWith(matrix.os, 'windows-') }}
-        run: |
-          echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-          echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-      - name: Install msys2 tools
-        if: ${{ startsWith(matrix.os, 'windows-') }}
-        run: |
-          Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
-      - uses: actions/setup-go@v5
-        with:
-          go-version-file: go.mod
-          cache: false
+          path: ${{ github.workspace }}\.ccache
+          key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
       - run: |
-          case ${{ matrix.arch }} in
-            amd64) echo ARCH=x86_64 ;;
-            arm64) echo ARCH=arm64 ;;
-          esac >>$GITHUB_ENV
-        shell: bash
-      - uses: golangci/golangci-lint-action@v6
-        with:
-          args: --timeout 10m0s -v
-
-  test:
-    strategy:
-      matrix:
-        os: [ubuntu-latest, macos-latest, windows-2019]
-        arch: [amd64]
-        exclude:
-          - os: ubuntu-latest
-            arch: arm64
-          - os: windows-2019
-            arch: arm64
-    runs-on: ${{ matrix.os }}
+          Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
+          Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
+          cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
+          cmake --build --parallel --preset "${{ matrix.preset }}"
     env:
-      GOARCH: ${{ matrix.arch }}
-      CGO_ENABLED: '1'
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          submodules: recursive
-      - name: Add msys paths
-        if: ${{ startsWith(matrix.os, 'windows-') }}
-        run: |
-          echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-          echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-      - name: Install msys2 tools
-        if: ${{ startsWith(matrix.os, 'windows-') }}
-        run: |
-          Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
-      - uses: actions/setup-go@v5
-        with:
-          go-version-file: go.mod
-          cache: true
-      - run: |
-          case ${{ matrix.arch }} in
-            amd64) echo ARCH=amd64 ;;
-            arm64) echo ARCH=arm64 ;;
-          esac >>$GITHUB_ENV
-        shell: bash
-      - run: go test ./...
+-      GOARCH: ${{ matrix.arch }}
++      CMAKE_GENERATOR: Ninja
+-      CGO_ENABLED: '1'
+-    steps:
+-      - uses: actions/checkout@v4
+-        with:
+-          submodules: recursive
+-      - name: Add msys paths
+-        if: ${{ startsWith(matrix.os, 'windows-') }}
+-        run: |
+-          echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
+-          echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
+-      - name: Install msys2 tools
+-        if: ${{ startsWith(matrix.os, 'windows-') }}
+-        run: |
+-          Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
+-      - uses: actions/setup-go@v5
+-        with:
+-          go-version-file: go.mod
+-          cache: true
+-      - run: |
+-          case ${{ matrix.arch }} in
+-            amd64) echo ARCH=amd64 ;;
+-            arm64) echo ARCH=arm64 ;;
+-          esac >>$GITHUB_ENV
+-        shell: bash
+-      - run: go test ./...
 
-  patches:
-    needs: [changes]
-    if: ${{ needs.changes.outputs.RUNNERS == 'True' }}
+-  patches:
++  go_mod_tidy:
+-    needs: [changes]
+-    if: ${{ needs.changes.outputs.RUNNERS == 'True' }}
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
+      - name: check that 'go mod tidy' is clean
+        run: go mod tidy --diff || (echo "Please run 'go mod tidy'." && exit 1)
+
+  test:
+    strategy:
+      matrix:
+        os: [ubuntu-latest, macos-latest, windows-latest]
+    runs-on: ${{ matrix.os }}
+    env:
+      CGO_ENABLED: '1'
+      GOEXPERIMENT: 'synctest'
+    steps:
+      - name: checkout
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
+
+      - name: cache restore
+        uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
         with:
-          submodules: recursive
-      - name: Verify patches carry all the changes
+          # Note: unlike the other setups, this is only grabbing the mod download
+          # cache, rather than the whole mod directory, as the download cache
+          # contains zips that can be unpacked in parallel faster than they can be
+          # fetched and extracted by tar
+          path: |
+            ~/.cache/go-build
+            ~/go/pkg/mod/cache
+            ~\AppData\Local\go-build
+          # NOTE: The -3- here should be incremented when the scheme of data to be
+          # cached changes (e.g. path above changes).
+          key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
+          restore-keys: |
+            ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}
+            ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-
+
+      - name: Setup Go
+        uses: actions/setup-go@v5
+        with:
+          # The caching strategy of setup-go is less than ideal, and wastes
+          # time by not saving artifacts due to small failures like the linter
+          # complaining, etc. This means subsequent have to rebuild their world
+          # again until all checks pass. For instance, if you mispell a word,
+          # you're punished until you fix it. This is more hostile than
+          # helpful.
+          cache: false
+
+          go-version-file: go.mod
+
+      # It is tempting to run this in a platform independent way, but the past
+      # shows this codebase will see introductions of platform specific code
+      # generation, and so we need to check this per platform to ensure we
+      # don't abuse go generate on specific platforms.
+      - name: check that 'go generate' is clean
+        if: always()
         run: |
-          make apply-patches sync && git diff --compact-summary --exit-code llama
+          go generate ./...
+          git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1)
+
+      - name: go test
+        if: always()
+        run: go test -count=1 -benchtime=1x ./...
+
+      # TODO(bmizerany): replace this heavy tool with just the
+      # tools/checks/binaries we want and then make them all run in parallel
+      # across jobs, not on a single tiny vm on Github Actions.
+      - uses: golangci/golangci-lint-action@v6
+        with:
+          args: --timeout 10m0s -v
+
+      - name: cache save
+        # Always save the cache, even if the job fails. The artifacts produced
+        # during the building of test binaries are not all for naught. They can
+        # be used to speed up subsequent runs.
+        if: always()
+        uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
+        with:
+          # Note: unlike the other setups, this is only grabbing the mod download
+          # cache, rather than the whole mod directory, as the download cache
+          # contains zips that can be unpacked in parallel faster than they can be
+          # fetched and extracted by tar
+          path: |
+            ~/.cache/go-build
+            ~/go/pkg/mod/cache
+            ~\AppData\Local\go-build
+          # NOTE: The -3- here should be incremented when the scheme of data to be
+          # cached changes (e.g. path above changes).
+          key: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-3-${{ hashFiles('**/go.sum') }}-${{ github.run_id }}
+
+  patches:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Verify patches apply cleanly and do not change files
+        run: |
+          make -f Makefile.sync clean checkout apply-patches sync
+          git diff --compact-summary --exit-code
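The rewritten `changes` job gates the heavier builds on whether vendored llama.cpp/GGML sources were touched. A minimal sketch of running the same check outside CI, assuming a local clone with `origin/main` as the base ref (the refs are illustrative; the workflow itself uses the pull request's base and head SHAs):

```sh
# Hypothetical local equivalent of the workflow's changed() helper.
BASE=origin/main   # assumption: stand-in for the PR base SHA
HEAD=HEAD          # assumption: stand-in for the PR head SHA
MERGE_BASE=$(git merge-base "$BASE" "$HEAD")
git diff-tree -r --no-commit-id --name-only "$MERGE_BASE" "$HEAD" \
  | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in 'llama/llama.cpp/** ml/backend/ggml/ggml/**'.split(' ')))"
```

This prints `True` or `False`, which is what the job writes to `$GITHUB_OUTPUT` as `changed`.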
.gitignore (vendored): 5 changed lines

@@ -4,12 +4,13 @@
 .venv
 .swp
 dist
-ollama
+build
 .cache
 *.exe
 .idea
 test_data
 *.crt
-llama/build
 __debug_bin*
+llama/build
 llama/vendor
+/ollama
@@ -6,8 +6,6 @@ linters:
   - bidichk
   - bodyclose
   - containedctx
-  - contextcheck
-  - errcheck
   - gocheckcompilerdirectives
   - gofmt
   - gofumpt
@@ -23,10 +21,11 @@ linters:
   - staticcheck
   - tenv
   - unconvert
-  - unused
-  - usestdlibvars
   - wastedassign
   - whitespace
+  disable:
+    - usestdlibvars
+    - errcheck
 linters-settings:
   staticcheck:
     checks:
@@ -39,5 +38,4 @@ severity:
   - gofmt
   - goimports
   - intrange
-  - usestdlibvars
 severity: info
CMakeLists.txt (new file): 133 lines

@@ -0,0 +1,133 @@
cmake_minimum_required(VERSION 3.21)

project(Ollama C CXX)

include(CheckLanguage)

find_package(Threads REQUIRED)

set(CMAKE_BUILD_TYPE Release)
set(BUILD_SHARED_LIBS ON)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

set(GGML_BUILD ON)
set(GGML_SHARED ON)
set(GGML_CCACHE ON)
set(GGML_BACKEND_DL ON)
set(GGML_BACKEND_SHARED ON)
set(GGML_SCHED_MAX_COPIES 4)

set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)
set(GGML_CUDA_COMPRESSION_MODE default)

if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
    OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
    set(GGML_CPU_ALL_VARIANTS ON)
endif()

if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
    set(CMAKE_BUILD_RPATH "@loader_path")
    set(CMAKE_INSTALL_RPATH "@loader_path")
endif()

set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)

set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${OLLAMA_BUILD_DIR})

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)

set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)

get_target_property(CPU_VARIANTS ggml-cpu MANUALLY_ADDED_DEPENDENCIES)
if(NOT CPU_VARIANTS)
    set(CPU_VARIANTS "ggml-cpu")
endif()

install(TARGETS ggml-base ${CPU_VARIANTS}
    RUNTIME_DEPENDENCIES
        PRE_EXCLUDE_REGEXES ".*"
    RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CPU
    LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CPU
    FRAMEWORK DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CPU
)

check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
    if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24" AND NOT CMAKE_CUDA_ARCHITECTURES)
        set(CMAKE_CUDA_ARCHITECTURES "native")
    endif()

    find_package(CUDAToolkit)
    add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
    set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
    install(TARGETS ggml-cuda
        RUNTIME_DEPENDENCIES
            DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
            PRE_INCLUDE_REGEXES cublas cublasLt cudart
            PRE_EXCLUDE_REGEXES ".*"
        RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
        LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
    )
endif()

set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
    CACHE STRING
    "Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
)

check_language(HIP)
if(CMAKE_HIP_COMPILER)
    set(HIP_PLATFORM "amd")

    find_package(hip REQUIRED)
    if(NOT AMDGPU_TARGETS)
        list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
    elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
        list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
    endif()

    if(AMDGPU_TARGETS)
        add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)

        if (WIN32)
            target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
        endif()

        target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)

        set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
        install(TARGETS ggml-hip
            RUNTIME_DEPENDENCIES
                DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
                PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
                PRE_EXCLUDE_REGEXES ".*"
                POST_EXCLUDE_REGEXES "system32"
            RUNTIME DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP
            LIBRARY DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP
        )

        foreach(HIP_LIB_BIN_INSTALL_DIR IN ITEMS ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR})
            if(EXISTS ${HIP_LIB_BIN_INSTALL_DIR}/rocblas)
                install(DIRECTORY ${HIP_LIB_BIN_INSTALL_DIR}/rocblas DESTINATION ${OLLAMA_HIP_INSTALL_DIR} COMPONENT HIP)
                break()
            endif()
        endforeach()
    endif()
endif()
CMakePresets.json (new file): 112 lines

@@ -0,0 +1,112 @@
{
  "version": 3,
  "configurePresets": [
    {
      "name": "Default",
      "binaryDir": "${sourceDir}/build",
      "installDir": "${sourceDir}/dist",
      "cacheVariables": {
        "CMAKE_BUILD_TYPE": "Release"
      }
    },
    {
      "name": "CPU",
      "inherits": [ "Default" ]
    },
    {
      "name": "CUDA",
      "inherits": [ "Default" ]
    },
    {
      "name": "CUDA 11",
      "inherits": [ "CUDA" ],
      "cacheVariables": {
        "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
        "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
      }
    },
    {
      "name": "CUDA 12",
      "inherits": [ "CUDA" ],
      "cacheVariables": {
        "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
        "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
      }
    },
    {
      "name": "JetPack 5",
      "inherits": [ "CUDA" ],
      "cacheVariables": {
        "CMAKE_CUDA_ARCHITECTURES": "72;87"
      }
    },
    {
      "name": "JetPack 6",
      "inherits": [ "CUDA" ],
      "cacheVariables": {
        "CMAKE_CUDA_ARCHITECTURES": "87"
      }
    },
    {
      "name": "ROCm",
      "inherits": [ "Default" ],
      "cacheVariables": {
        "CMAKE_HIP_PLATFORM": "amd"
      }
    },
    {
      "name": "ROCm 6",
      "inherits": [ "ROCm" ],
      "cacheVariables": {
        "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
      }
    }
  ],
  "buildPresets": [
    {
      "name": "Default",
      "configurePreset": "Default",
      "configuration": "Release"
    },
    {
      "name": "CPU",
      "configurePreset": "Default",
      "targets": [ "ggml-cpu" ]
    },
    {
      "name": "CUDA",
      "configurePreset": "CUDA",
      "targets": [ "ggml-cuda" ]
    },
    {
      "name": "CUDA 11",
      "inherits": [ "CUDA" ],
      "configurePreset": "CUDA 11"
    },
    {
      "name": "CUDA 12",
      "inherits": [ "CUDA" ],
      "configurePreset": "CUDA 12"
    },
    {
      "name": "JetPack 5",
      "inherits": [ "CUDA" ],
      "configurePreset": "JetPack 5"
    },
    {
      "name": "JetPack 6",
      "inherits": [ "CUDA" ],
      "configurePreset": "JetPack 6"
    },
    {
      "name": "ROCm",
      "configurePreset": "ROCm",
      "targets": [ "ggml-hip" ]
    },
    {
      "name": "ROCm 6",
      "inherits": [ "ROCm" ],
      "configurePreset": "ROCm 6"
    }
  ]
}
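These presets are what the new CI workflow and Dockerfile stages drive. As a rough usage sketch, assuming a checkout of the repository root containing this CMakePresets.json and a toolchain matching the chosen preset ('CUDA 12' is just an example choice):

```sh
# Configure, build, and install a single preset.
cmake --preset 'CUDA 12'
cmake --build --parallel --preset 'CUDA 12'
cmake --install build --component CUDA --strip
```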
@@ -6,8 +6,6 @@ Thank you for your interest in contributing to Ollama! Here are a few guidelines
 
 See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.
 
-## Pull requests
-
 ### Ideal issues
 
 * [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.
@@ -26,11 +24,64 @@ See the [development documentation](./docs/development.md) for instructions on h
 * Changes that add significant friction to the user experience
 * Changes that create a large future maintenance burden for maintainers and contributors
 
-### Best practices
+## Proposing a (non-trivial) change
 
-* Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`) . In the description, leave a short 2-3 sentences that explain more about the change and its impact.
-* Tests: please add test coverage to changes where possible.
-* Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
+> By "non-trivial", we mean a change that is not a bug fix or small
+> documentation update. If you are unsure, please ask us on our [Discord
+> server](https://discord.gg/ollama).
 
+Before opening a non-trivial Pull Request, please open an issue to discuss the change and
+get feedback from the maintainers. This helps us understand the context of the
+change and how it fits into Ollama's roadmap and prevents us from duplicating
+work or you from spending time on a change that we may not be able to accept.
+
+Tips for proposals:
+
+* Explain the problem you are trying to solve, not what you are trying to do.
+* Explain why the change is important.
+* Explain how the change will be used.
+* Explain how the change will be tested.
+
+Additionally, for bonus points: Provide draft documentation you would expect to
+see if the change were accepted.
+
+## Pull requests
+
+**Commit messages**
+
+The title should look like:
+
+    <package>: <short description>
+
+The package is the most affected Go package. If the change does not affect Go
+code, then use the directory name instead. Changes to a single well-known
+file in the root directory may use the file name.
+
+The short description should start with a lowercase letter and be a
+continuation of the sentence:
+
+    "This changes Ollama to..."
+
+Examples:
+
+    llm/backend/mlx: support the llama architecture
+    CONTRIBUTING: provide clairity on good commit messages, and bad
+
+Bad Examples:
+
+    feat: add more emoji
+    fix: was not using famous web framework
+    chore: generify code
+
+**Tests**
+
+Please include tests. Strive to test behavior, not implementation.
+
+**New dependencies**
+
+Dependencies should be added sparingly. If you are adding a new dependency,
+please explain why it is necessary and what other ways you attempted that
+did not work without it.
+
 ## Need help?
 
Dockerfile: 284 changed lines

@@ -1,201 +1,131 @@
-ARG GOLANG_VERSION=1.22.8
-ARG CUDA_VERSION_11=11.3.1
-ARG CUDA_VERSION_12=12.4.0
-ARG ROCM_VERSION=6.1.2
-ARG JETPACK_6=r36.2.0
-ARG JETPACK_5=r35.4.1
+# vim: filetype=dockerfile
 
-### To create a local image for building linux binaries on mac or windows with efficient incremental builds
-#
-# docker build --platform linux/amd64 -t builder-amd64 -f Dockerfile --target unified-builder-amd64 .
-# docker run --platform linux/amd64 --rm -it -v $(pwd):/go/src/github.com/ollama/ollama/ builder-amd64
-#
-### Then incremental builds will be much faster in this container
-#
-# make -j 10 dist
-#
-FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS unified-builder-amd64
-ARG GOLANG_VERSION
-ARG CUDA_VERSION_11
-ARG CUDA_VERSION_12
-COPY ./scripts/rh_linux_deps.sh /
-ENV PATH /opt/rh/devtoolset-10/root/usr/bin:/usr/local/cuda/bin:$PATH
-ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/lib64
-RUN GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
-RUN yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo && \
-    dnf clean all && \
-    dnf install -y \
-    zsh \
-    cuda-toolkit-$(echo ${CUDA_VERSION_11} | cut -f1-2 -d. | sed -e "s/\./-/g") \
-    cuda-toolkit-$(echo ${CUDA_VERSION_12} | cut -f1-2 -d. | sed -e "s/\./-/g")
-# TODO intel oneapi goes here...
-ENV GOARCH amd64
-ENV CGO_ENABLED 1
-WORKDIR /go/src/github.com/ollama/ollama/
-ENTRYPOINT [ "zsh" ]
+ARG FLAVOR=${TARGETARCH}
 
-### To create a local image for building linux binaries on mac or linux/arm64 with efficient incremental builds
-# Note: this does not contain jetson variants
-#
-# docker build --platform linux/arm64 -t builder-arm64 -f Dockerfile --target unified-builder-arm64 .
-# docker run --platform linux/arm64 --rm -it -v $(pwd):/go/src/github.com/ollama/ollama/ builder-arm64
-#
-FROM --platform=linux/arm64 rockylinux:8 AS unified-builder-arm64
-ARG GOLANG_VERSION
-ARG CUDA_VERSION_11
-ARG CUDA_VERSION_12
-COPY ./scripts/rh_linux_deps.sh /
-RUN GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
-RUN yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo && \
-    dnf config-manager --set-enabled appstream && \
-    dnf clean all && \
-    dnf install -y \
-    zsh \
-    cuda-toolkit-$(echo ${CUDA_VERSION_11} | cut -f1-2 -d. | sed -e "s/\./-/g") \
-    cuda-toolkit-$(echo ${CUDA_VERSION_12} | cut -f1-2 -d. | sed -e "s/\./-/g")
-ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH:/usr/local/cuda/bin
-ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/lib64
-ENV LIBRARY_PATH=/usr/local/cuda/lib64/stubs:/opt/amdgpu/lib64
-ENV GOARCH arm64
-ENV CGO_ENABLED 1
-WORKDIR /go/src/github.com/ollama/ollama/
-ENTRYPOINT [ "zsh" ]
+ARG ROCMVERSION=6.3.3
+ARG JETPACK5VERSION=r35.4.1
+ARG JETPACK6VERSION=r36.4.0
+ARG CMAKEVERSION=3.31.2
 
-FROM --platform=linux/amd64 unified-builder-amd64 AS build-amd64
-COPY . .
-ARG OLLAMA_SKIP_CUDA_GENERATE
-ARG OLLAMA_SKIP_ROCM_GENERATE
-ARG OLLAMA_FAST_BUILD
-ARG VERSION
-ARG CUSTOM_CPU_FLAGS
+# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
+FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
+RUN yum install -y yum-utils \
+    && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
+    && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
+    && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
+    && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
+ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
+
+FROM --platform=linux/arm64 almalinux:8 AS base-arm64
+# install epel-release for ccache
+RUN yum install -y yum-utils epel-release \
+    && dnf install -y clang ccache \
+    && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
+ENV CC=clang CXX=clang++
+
+FROM base-${TARGETARCH} AS base
+ARG CMAKEVERSION
+RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
+COPY CMakeLists.txt CMakePresets.json .
+COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
+ENV LDFLAGS=-s
+
+FROM base AS cpu
+RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
+ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
 RUN --mount=type=cache,target=/root/.ccache \
-    if grep "^flags" /proc/cpuinfo|grep avx>/dev/null; then \
-        make -j $(nproc) dist ; \
-    else \
-        make -j 5 dist ; \
-    fi
-RUN cd dist/linux-$GOARCH && \
-    tar -cf - . | pigz --best > ../ollama-linux-$GOARCH.tgz
-RUN if [ -z ${OLLAMA_SKIP_ROCM_GENERATE} ] ; then \
-    cd dist/linux-$GOARCH-rocm && \
-    tar -cf - . | pigz --best > ../ollama-linux-$GOARCH-rocm.tgz ;\
-    fi
+    cmake --preset 'CPU' \
+        && cmake --build --parallel --preset 'CPU' \
+        && cmake --install build --component CPU --strip --parallel 8
 
-# Jetsons need to be built in discrete stages
-FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK_5} AS runners-jetpack5-arm64
-ARG GOLANG_VERSION
-RUN apt-get update && apt-get install -y git curl ccache && \
-    curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-arm64.tar.gz | tar xz -C /usr/local && \
-    ln -s /usr/local/go/bin/go /usr/local/bin/go && \
-    ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt && \
-    apt-get clean && rm -rf /var/lib/apt/lists/*
-WORKDIR /go/src/github.com/ollama/ollama/
-COPY . .
-ARG CGO_CFLAGS
-ENV GOARCH arm64
-ARG VERSION
+FROM base AS cuda-11
+ARG CUDA11VERSION=11.3
+RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
+ENV PATH=/usr/local/cuda-11/bin:$PATH
 RUN --mount=type=cache,target=/root/.ccache \
-    make -j 5 dist_cuda_v11 \
-        CUDA_ARCHITECTURES="72;87" \
-        GPU_RUNNER_VARIANT=_jetpack5 \
-        DIST_LIB_DIR=/go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack5/lib/ollama \
-        DIST_GPU_RUNNER_DEPS_DIR=/go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack5/lib/ollama/cuda_jetpack5
+    cmake --preset 'CUDA 11' \
+        && cmake --build --parallel --preset 'CUDA 11' \
+        && cmake --install build --component CUDA --strip --parallel 8
 
-FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK_6} AS runners-jetpack6-arm64
-ARG GOLANG_VERSION
-RUN apt-get update && apt-get install -y git curl ccache && \
-    curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-arm64.tar.gz | tar xz -C /usr/local && \
-    ln -s /usr/local/go/bin/go /usr/local/bin/go && \
-    ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt && \
-    apt-get clean && rm -rf /var/lib/apt/lists/*
-WORKDIR /go/src/github.com/ollama/ollama/
-COPY . .
-ARG CGO_CFLAGS
-ENV GOARCH arm64
-ARG VERSION
+FROM base AS cuda-12
+ARG CUDA12VERSION=12.8
+RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
+ENV PATH=/usr/local/cuda-12/bin:$PATH
 RUN --mount=type=cache,target=/root/.ccache \
-    make -j 5 dist_cuda_v12 \
-        CUDA_ARCHITECTURES="87" \
-        GPU_RUNNER_VARIANT=_jetpack6 \
-        DIST_LIB_DIR=/go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack6/lib/ollama \
-        DIST_GPU_RUNNER_DEPS_DIR=/go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack6/lib/ollama/cuda_jetpack6
+    cmake --preset 'CUDA 12' \
+        && cmake --build --parallel --preset 'CUDA 12' \
+        && cmake --install build --component CUDA --strip --parallel 8
 
-FROM --platform=linux/arm64 unified-builder-arm64 AS build-arm64
-COPY . .
-ARG OLLAMA_SKIP_CUDA_GENERATE
-ARG OLLAMA_FAST_BUILD
-ARG VERSION
+FROM base AS rocm-6
+ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
 RUN --mount=type=cache,target=/root/.ccache \
-    make -j 5 dist
-COPY --from=runners-jetpack5-arm64 /go/src/github.com/ollama/ollama/dist/ dist/
-COPY --from=runners-jetpack6-arm64 /go/src/github.com/ollama/ollama/dist/ dist/
-RUN cd dist/linux-$GOARCH && \
-    tar -cf - . | pigz --best > ../ollama-linux-$GOARCH.tgz
-RUN cd dist/linux-$GOARCH-jetpack5 && \
-    tar -cf - . | pigz --best > ../ollama-linux-$GOARCH-jetpack5.tgz
-RUN cd dist/linux-$GOARCH-jetpack6 && \
-    tar -cf - . | pigz --best > ../ollama-linux-$GOARCH-jetpack6.tgz
+    cmake --preset 'ROCm 6' \
+        && cmake --build --parallel --preset 'ROCm 6' \
+        && cmake --install build --component HIP --strip --parallel 8
 
-FROM --platform=linux/amd64 scratch AS dist-amd64
-COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/ollama-linux-*.tgz /
-FROM --platform=linux/arm64 scratch AS dist-arm64
-COPY --from=build-arm64 /go/src/github.com/ollama/ollama/dist/ollama-linux-*.tgz /
+FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
+ARG CMAKEVERSION
+RUN apt-get update && apt-get install -y curl ccache \
+    && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||||
FROM dist-$TARGETARCH AS dist
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
|
cmake --preset 'JetPack 5' \
|
||||||
|
&& cmake --build --parallel --preset 'JetPack 5' \
|
||||||
|
&& cmake --install build --component CUDA --strip --parallel 8
|
||||||
|
|
||||||
|
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||||
|
ARG CMAKEVERSION
|
||||||
|
RUN apt-get update && apt-get install -y curl ccache \
|
||||||
|
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||||
|
COPY CMakeLists.txt CMakePresets.json .
|
||||||
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
|
cmake --preset 'JetPack 6' \
|
||||||
|
&& cmake --build --parallel --preset 'JetPack 6' \
|
||||||
|
&& cmake --install build --component CUDA --strip --parallel 8
|
||||||
|
|
||||||
# For amd64 container images, filter out cuda/rocm to minimize size
|
FROM base AS build
|
||||||
FROM build-amd64 AS runners-cuda-amd64
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
RUN rm -rf \
|
COPY go.mod go.sum .
|
||||||
./dist/linux-amd64/lib/ollama/libggml_hipblas.so \
|
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
./dist/linux-amd64/lib/ollama/runners/rocm*
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
|
RUN go mod download
|
||||||
|
COPY . .
|
||||||
|
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||||
|
ENV CGO_ENABLED=1
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||||
|
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||||
|
|
||||||
FROM build-amd64 AS runners-rocm-amd64
|
FROM --platform=linux/amd64 scratch AS amd64
|
||||||
RUN rm -rf \
|
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||||
./dist/linux-amd64/lib/ollama/libggml_cuda*.so \
|
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||||
./dist/linux-amd64/lib/ollama/libcu*.so* \
|
|
||||||
./dist/linux-amd64/lib/ollama/runners/cuda*
|
|
||||||
|
|
||||||
FROM --platform=linux/amd64 ubuntu:22.04 AS runtime-amd64
|
FROM --platform=linux/arm64 scratch AS arm64
|
||||||
RUN apt-get update && \
|
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||||
apt-get install -y ca-certificates && \
|
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||||
apt-get clean && rm -rf /var/lib/apt/lists/*
|
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
|
||||||
COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/bin/ /bin/
|
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
|
||||||
COPY --from=runners-cuda-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/
|
|
||||||
|
|
||||||
FROM --platform=linux/arm64 ubuntu:22.04 AS runtime-arm64
|
FROM scratch AS rocm
|
||||||
RUN apt-get update && \
|
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
|
||||||
apt-get install -y ca-certificates && \
|
|
||||||
apt-get clean && rm -rf /var/lib/apt/lists/*
|
|
||||||
COPY --from=build-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/bin/ /bin/
|
|
||||||
COPY --from=build-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/lib/ /lib/
|
|
||||||
COPY --from=runners-jetpack5-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack5/lib/ /lib/
|
|
||||||
COPY --from=runners-jetpack6-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack6/lib/ /lib/
|
|
||||||
|
|
||||||
|
FROM ${FLAVOR} AS archive
|
||||||
|
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||||
|
COPY --from=build /bin/ollama /bin/ollama
|
||||||
|
|
||||||
# ROCm libraries larger so we keep it distinct from the CPU/CUDA image
|
FROM ubuntu:20.04
|
||||||
FROM --platform=linux/amd64 ubuntu:22.04 AS runtime-rocm
|
RUN apt-get update \
|
||||||
# Frontload the rocm libraries which are large, and rarely change to increase chance of a common layer
|
&& apt-get install -y ca-certificates \
|
||||||
# across releases
|
&& apt-get clean \
|
||||||
COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64-rocm/lib/ /lib/
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
RUN apt-get update && \
|
COPY --from=archive /bin /usr/bin
|
||||||
apt-get install -y ca-certificates && \
|
|
||||||
apt-get clean && rm -rf /var/lib/apt/lists/*
|
|
||||||
COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/bin/ /bin/
|
|
||||||
COPY --from=runners-rocm-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/
|
|
||||||
|
|
||||||
EXPOSE 11434
|
|
||||||
ENV OLLAMA_HOST 0.0.0.0
|
|
||||||
|
|
||||||
ENTRYPOINT ["/bin/ollama"]
|
|
||||||
CMD ["serve"]
|
|
||||||
|
|
||||||
FROM runtime-$TARGETARCH
|
|
||||||
EXPOSE 11434
|
|
||||||
ENV OLLAMA_HOST 0.0.0.0
|
|
||||||
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||||
|
COPY --from=archive /lib/ollama /usr/lib/ollama
|
||||||
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||||
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
||||||
ENV NVIDIA_VISIBLE_DEVICES=all
|
ENV NVIDIA_VISIBLE_DEVICES=all
|
||||||
|
ENV OLLAMA_HOST=0.0.0.0:11434
|
||||||
|
EXPOSE 11434
|
||||||
ENTRYPOINT ["/bin/ollama"]
|
ENTRYPOINT ["/bin/ollama"]
|
||||||
CMD ["serve"]
|
CMD ["serve"]
|
||||||
|
103  Makefile
@@ -1,103 +0,0 @@
# top level makefile for Ollama
include make/common-defs.make

# Determine which if any GPU runners we should build
include make/cuda-v11-defs.make
include make/cuda-v12-defs.make
include make/rocm-defs.make

ifeq ($(CUSTOM_CPU_FLAGS),)
ifeq ($(ARCH),amd64)
	RUNNER_TARGETS=cpu
endif
# Without CUSTOM_CPU_FLAGS we default to build both v11 and v12 if present
ifeq ($(OLLAMA_SKIP_CUDA_GENERATE),)
ifneq ($(CUDA_11_COMPILER),)
	RUNNER_TARGETS += cuda_v11
endif
ifneq ($(CUDA_12_COMPILER),)
	RUNNER_TARGETS += cuda_v12
endif
endif
else # CUSTOM_CPU_FLAGS is set, we'll build only the latest cuda version detected
ifneq ($(CUDA_12_COMPILER),)
	RUNNER_TARGETS += cuda_v12
else ifneq ($(CUDA_11_COMPILER),)
	RUNNER_TARGETS += cuda_v11
endif
endif

ifeq ($(OLLAMA_SKIP_ROCM_GENERATE),)
ifneq ($(HIP_COMPILER),)
	RUNNER_TARGETS += rocm
endif
endif

all: runners exe

dist: $(addprefix dist_, $(RUNNER_TARGETS)) dist_exe

dist_%:
	@$(MAKE) --no-print-directory -f make/Makefile.$* dist

runners: $(RUNNER_TARGETS)

$(RUNNER_TARGETS):
	@$(MAKE) --no-print-directory -f make/Makefile.$@

exe dist_exe:
	@$(MAKE) --no-print-directory -f make/Makefile.ollama $@

help-sync apply-patches create-patches sync sync-clean:
	@$(MAKE) --no-print-directory -f make/Makefile.sync $@

test integration lint:
	@$(MAKE) --no-print-directory -f make/Makefile.test $@

clean:
	rm -rf $(BUILD_DIR) $(DIST_LIB_DIR) $(OLLAMA_EXE) $(DIST_OLLAMA_EXE)
	go clean -cache

help:
	@echo "The following make targets will help you build Ollama"
	@echo ""
	@echo " make all # (default target) Build Ollama llm subprocess runners, and the primary ollama executable"
	@echo " make runners # Build Ollama llm subprocess runners; after you may use 'go build .' to build the primary ollama exectuable"
	@echo " make <runner> # Build specific runners. Enabled: '$(RUNNER_TARGETS)'"
	@echo " make dist # Build the runners and primary ollama executable for distribution"
	@echo " make help-sync # Help information on vendor update targets"
	@echo " make help-runners # Help information on runner targets"
	@echo ""
	@echo "The following make targets will help you test Ollama"
	@echo ""
	@echo " make test # Run unit tests"
	@echo " make integration # Run integration tests. You must 'make all' first"
	@echo " make lint # Run lint and style tests"
	@echo ""
	@echo "For more information see 'docs/development.md'"
	@echo ""

help-runners:
	@echo "The following runners will be built based on discovered GPU libraries: '$(RUNNER_TARGETS)'"
	@echo ""
	@echo "GPU Runner CPU Flags: '$(GPU_RUNNER_CPU_FLAGS)' (Override with CUSTOM_CPU_FLAGS)"
	@echo ""
	@echo "# CUDA_PATH sets the location where CUDA toolkits are present"
	@echo "CUDA_PATH=$(CUDA_PATH)"
	@echo " CUDA_11_PATH=$(CUDA_11_PATH)"
	@echo " CUDA_11_COMPILER=$(CUDA_11_COMPILER)"
	@echo " CUDA_12_PATH=$(CUDA_12_PATH)"
	@echo " CUDA_12_COMPILER=$(CUDA_12_COMPILER)"
	@echo ""
	@echo "# HIP_PATH sets the location where the ROCm toolkit is present"
	@echo "HIP_PATH=$(HIP_PATH)"
	@echo " HIP_COMPILER=$(HIP_COMPILER)"

.PHONY: all exe dist help help-sync help-runners test integration lint runners clean $(RUNNER_TARGETS)

# Handy debugging for make variables
print-%:
	@echo '$*=$($*)'
60  Makefile.sync  (new file)
@@ -0,0 +1,60 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=2016f07bd106c73699ecbaace80f55db5ed95dac

.PHONY: help
help:
	@echo "Available targets:"
	@echo " sync Sync with upstream repositories"
	@echo " checkout Checkout upstream repository"
	@echo " apply-patches Apply patches to local repository"
	@echo " format-patches Format patches from local repository"
	@echo " clean Clean local repository"
	@echo
	@echo "Example:"
	@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"

.PHONY: sync
sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml

.PHONY: llama/build-info.cpp
llama/build-info.cpp: llama/build-info.cpp.in
	sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@

.PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/
	rsync -arvzc -f "merge $@/.rsync-filter" $< $@

.PHONY: ml/backend/ggml/ggml
ml/backend/ggml/ggml: llama/vendor/ggml/
	rsync -arvzc -f "merge $@/.rsync-filter" $< $@

PATCHES=$(wildcard llama/patches/*.patch)

.PHONY: apply-patches
.NOTPARALLEL:
apply-patches: $(addsuffix ed, $(PATCHES))

%.patched: %.patch
	@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi

.PHONY: checkout
checkout: $(WORKDIR)
	git -C $(WORKDIR) fetch
	git -C $(WORKDIR) checkout -f $(FETCH_HEAD)

$(WORKDIR):
	git clone $(UPSTREAM) $(WORKDIR)

.PHONE: format-patches
format-patches: llama/patches
	git -C $(WORKDIR) format-patch \
		--no-signature \
		--no-numbered \
		--zero-commit \
		-o $(realpath $<) \
		$(FETCH_HEAD)

.PHONE: clean
clean: checkout
	$(RM) $(addsuffix ed, $(PATCHES))
110  README.md
@@ -1,5 +1,5 @@
 <div align="center">
-  <a href="https://ollama.com" />
+  <a href="https://ollama.com">
     <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
   </a>
 </div>
@@ -18,7 +18,7 @@ Get up and running with large language models.
 ### Linux
 
-```
+```shell
 curl -fsSL https://ollama.com/install.sh | sh
 ```
@@ -42,7 +42,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
 To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
 
-```
+```shell
 ollama run llama3.2
 ```
@@ -54,6 +54,13 @@ Here are some example models that can be downloaded:
 | Model              | Parameters | Size  | Download                         |
 | ------------------ | ---------- | ----- | -------------------------------- |
+| Gemma 3            | 1B         | 815MB | `ollama run gemma3:1b`           |
+| Gemma 3            | 4B         | 3.3GB | `ollama run gemma3`              |
+| Gemma 3            | 12B        | 8.1GB | `ollama run gemma3:12b`          |
+| Gemma 3            | 27B        | 17GB  | `ollama run gemma3:27b`          |
+| QwQ                | 32B        | 20GB  | `ollama run qwq`                 |
+| DeepSeek-R1        | 7B         | 4.7GB | `ollama run deepseek-r1`         |
+| DeepSeek-R1        | 671B       | 404GB | `ollama run deepseek-r1:671b`    |
 | Llama 3.3          | 70B        | 43GB  | `ollama run llama3.3`            |
 | Llama 3.2          | 3B         | 2.0GB | `ollama run llama3.2`            |
 | Llama 3.2          | 1B         | 1.3GB | `ollama run llama3.2:1b`         |
@@ -62,10 +69,7 @@ Here are some example models that can be downloaded:
 | Llama 3.1          | 8B         | 4.7GB | `ollama run llama3.1`            |
 | Llama 3.1          | 405B       | 231GB | `ollama run llama3.1:405b`       |
 | Phi 4              | 14B        | 9.1GB | `ollama run phi4`                |
-| Phi 3 Mini         | 3.8B       | 2.3GB | `ollama run phi3`                |
+| Phi 4 Mini         | 3.8B       | 2.5GB | `ollama run phi4-mini`           |
-| Gemma 2            | 2B         | 1.6GB | `ollama run gemma2:2b`           |
-| Gemma 2            | 9B         | 5.5GB | `ollama run gemma2`              |
-| Gemma 2            | 27B        | 16GB  | `ollama run gemma2:27b`          |
 | Mistral            | 7B         | 4.1GB | `ollama run mistral`             |
 | Moondream 2        | 1.4B       | 829MB | `ollama run moondream`           |
 | Neural Chat        | 7B         | 4.1GB | `ollama run neural-chat`         |
@@ -73,7 +77,7 @@ Here are some example models that can be downloaded:
 | Code Llama         | 7B         | 3.8GB | `ollama run codellama`           |
 | Llama 2 Uncensored | 7B         | 3.8GB | `ollama run llama2-uncensored`   |
 | LLaVA              | 7B         | 4.5GB | `ollama run llava`               |
-| Solar              | 10.7B      | 6.1GB | `ollama run solar`               |
+| Granite-3.2        | 8B         | 4.9GB | `ollama run granite3.2`          |
 
 > [!NOTE]
 > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -92,13 +96,13 @@ Ollama supports importing GGUF models in the Modelfile:
 2. Create the model in Ollama
 
-```
+```shell
 ollama create example -f Modelfile
 ```
 
 3. Run the model
 
-```
+```shell
 ollama run example
 ```
@@ -110,7 +114,7 @@ See the [guide](docs/import.md) on importing models for more information.
 Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3.2` model:
 
-```
+```shell
 ollama pull llama3.2
 ```
@@ -145,13 +149,13 @@ For more information on working with a Modelfile, see the [Modelfile](docs/model
 `ollama create` is used to create a model from a Modelfile.
 
-```
+```shell
 ollama create mymodel -f ./Modelfile
 ```
 
 ### Pull a model
 
-```
+```shell
 ollama pull llama3.2
 ```
@@ -159,13 +163,13 @@ ollama pull llama3.2
 ### Remove a model
 
-```
+```shell
 ollama rm llama3.2
 ```
 
 ### Copy a model
 
-```
+```shell
 ollama cp llama3.2 my-model
 ```
@@ -184,37 +188,39 @@ I'm a basic program that prints the famous "Hello, world!" message to the console.
 ```
 ollama run llava "What's in this image? /Users/jmorgan/Desktop/smile.png"
-The image features a yellow smiley face, which is likely the central focus of the picture.
 ```
 
+> **Output**: The image features a yellow smiley face, which is likely the central focus of the picture.
 
 ### Pass the prompt as an argument
 
+```shell
+ollama run llama3.2 "Summarize this file: $(cat README.md)"
 ```
-$ ollama run llama3.2 "Summarize this file: $(cat README.md)"
-Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
+> **Output**: Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
-```
 
 ### Show model information
 
-```
+```shell
 ollama show llama3.2
 ```
 
 ### List models on your computer
 
-```
+```shell
 ollama list
 ```
 
 ### List which models are currently loaded
 
-```
+```shell
 ollama ps
 ```
 
 ### Stop a model which is currently running
 
-```
+```shell
 ollama stop llama3.2
 ```
@@ -230,13 +236,13 @@ See the [developer guide](https://github.com/ollama/ollama/blob/main/docs/develo
 Next, start the server:
 
-```
+```shell
 ./ollama serve
 ```
 
 Finally, in a separate shell, run a model:
 
-```
+```shell
 ./ollama run llama3.2
 ```
@@ -246,7 +252,7 @@ Ollama has a REST API for running and managing models.
 ### Generate a response
 
-```
+```shell
 curl http://localhost:11434/api/generate -d '{
   "model": "llama3.2",
   "prompt":"Why is the sky blue?"
@@ -255,7 +261,7 @@ curl http://localhost:11434/api/generate -d '{
 ### Chat with a model
 
-```
+```shell
 curl http://localhost:11434/api/chat -d '{
   "model": "llama3.2",
   "messages": [
@@ -271,6 +277,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Web & Desktop
 
 - [Open WebUI](https://github.com/open-webui/open-webui)
+- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
 - [Hollama](https://github.com/fmaclen/hollama)
 - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
@@ -278,12 +285,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
 - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
 - [Saddle](https://github.com/jikkuatwork/saddle)
+- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
 - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
 - [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
 - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
 - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
 - [Ollamac](https://github.com/kevinhermawan/Ollamac)
-- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
+- [big-AGI](https://github.com/enricoros/big-AGI)
 - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
 - [Amica](https://github.com/semperai/amica)
 - [chatd](https://github.com/BruceMacD/chatd)
@@ -317,6 +325,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
 - [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
 - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
+- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
 - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
 - [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
@@ -339,7 +348,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
 - [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
 - [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
-- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
+- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
 - [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
 - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
 - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -353,6 +362,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
 - [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
+- [chat-ollama](https://github.com/annilq/chat-ollama) (a React Native client for Ollama)
 - [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
 - [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
 - [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
@@ -369,6 +379,26 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
 - [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
 - [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
+- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
+- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
+- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
+- [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
+- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
+- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
+- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
+- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
+- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
+- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
+- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
+- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
+- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
+- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
+- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
+- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
+- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
+- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
+- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
+- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
 
 ### Cloud
@@ -408,10 +438,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
 - [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
 - [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
+- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
 - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
+- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
+- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
 
 ### Apple Vision Pro
 
+- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
 - [Enchanted](https://github.com/AugustDev/enchanted)
 
 ### Database
@@ -426,9 +460,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
 - [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama)
+- [Homebrew](https://formulae.brew.sh/formula/ollama)
 - [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama)
 - [Guix channel](https://codeberg.org/tusharhero/ollama-guix)
-- [Nix package](https://search.nixos.org/packages?channel=24.05&show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
+- [Nix package](https://search.nixos.org/packages?show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
 - [Flox](https://flox.dev/blog/ollama-part-one)
 
 ### Libraries
@@ -481,13 +516,21 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [GoLamify](https://github.com/prasad89/golamify)
 - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
 - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
+- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
+- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
+- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
+- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
+- [Ollama for D](https://github.com/kassane/ollama-d)
 
 ### Mobile
 
+- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
 - [Enchanted](https://github.com/AugustDev/enchanted)
 - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
 - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
+- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
+- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
 
 ### Extensions & Plugins
@@ -531,13 +574,18 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
 - [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
 - [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
+- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
+- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
+- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
 
 ### Supported backends
 
 - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
 
 ### Observability
+- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
+- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
 - [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
 - [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
+- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.

api/client.go
@@ -10,7 +10,7 @@
 // repository].
 //
 // [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
-// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
+// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/api/examples
 package api
 
 import (
@@ -132,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 const maxBufferSize = 512 * format.KiloByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
-	var buf *bytes.Buffer
+	var buf io.Reader
 	if data != nil {
 		bts, err := json.Marshal(data)
 		if err != nil {

api/client_test.go
@@ -1,6 +1,13 @@
 package api
 
 import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strings"
 	"testing"
 )
@@ -43,3 +50,206 @@ func TestClientFromEnvironment(t *testing.T) {
 		})
 	}
 }
+
+// testError represents an internal error type with status code and message
+// this is used since the error response from the server is not a standard error struct
+type testError struct {
+	message    string
+	statusCode int
+}
+
+func (e testError) Error() string {
+	return e.message
+}
+
+func TestClientStream(t *testing.T) {
+	testCases := []struct {
+		name      string
+		responses []any
+		wantErr   string
+	}{
+		{
+			name: "immediate error response",
+			responses: []any{
+				testError{
+					message:    "test error message",
+					statusCode: http.StatusBadRequest,
+				},
+			},
+			wantErr: "test error message",
+		},
+		{
+			name: "error after successful chunks, ok response",
+			responses: []any{
+				ChatResponse{Message: Message{Content: "partial response 1"}},
+				ChatResponse{Message: Message{Content: "partial response 2"}},
+				testError{
+					message:    "mid-stream error",
+					statusCode: http.StatusOK,
+				},
+			},
+			wantErr: "mid-stream error",
+		},
+		{
+			name: "successful stream completion",
+			responses: []any{
+				ChatResponse{Message: Message{Content: "chunk 1"}},
+				ChatResponse{Message: Message{Content: "chunk 2"}},
+				ChatResponse{
+					Message:    Message{Content: "final chunk"},
+					Done:       true,
+					DoneReason: "stop",
+				},
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				flusher, ok := w.(http.Flusher)
+				if !ok {
+					t.Fatal("expected http.Flusher")
+				}
+
+				w.Header().Set("Content-Type", "application/x-ndjson")
+
+				for _, resp := range tc.responses {
+					if errResp, ok := resp.(testError); ok {
+						w.WriteHeader(errResp.statusCode)
+						err := json.NewEncoder(w).Encode(map[string]string{
+							"error": errResp.message,
+						})
+						if err != nil {
+							t.Fatal("failed to encode error response:", err)
+						}
+						return
+					}
+
+					if err := json.NewEncoder(w).Encode(resp); err != nil {
+						t.Fatalf("failed to encode response: %v", err)
+					}
+					flusher.Flush()
+				}
+			}))
+			defer ts.Close()
+
+			client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
+
+			var receivedChunks []ChatResponse
+			err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
+				var resp ChatResponse
+				if err := json.Unmarshal(chunk, &resp); err != nil {
+					return fmt.Errorf("failed to unmarshal chunk: %w", err)
+				}
+				receivedChunks = append(receivedChunks, resp)
+				return nil
+			})
+
+			if tc.wantErr != "" {
+				if err == nil {
+					t.Fatal("expected error but got nil")
+				}
+				if !strings.Contains(err.Error(), tc.wantErr) {
+					t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
+				}
+				return
+			}
+			if err != nil {
+				t.Errorf("unexpected error: %v", err)
+			}
+		})
+	}
+}
+
+func TestClientDo(t *testing.T) {
+	testCases := []struct {
+		name     string
+		response any
+		wantErr  string
+	}{
+		{
+			name: "immediate error response",
+			response: testError{
+				message:    "test error message",
+				statusCode: http.StatusBadRequest,
+			},
+			wantErr: "test error message",
+		},
+		{
+			name: "server error response",
+			response: testError{
+				message:    "internal error",
+				statusCode: http.StatusInternalServerError,
+			},
+			wantErr: "internal error",
+		},
+		{
+			name: "successful response",
+			response: struct {
+				ID      string `json:"id"`
+				Success bool   `json:"success"`
+			}{
+				ID:      "msg_123",
+				Success: true,
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				if errResp, ok := tc.response.(testError); ok {
+					w.WriteHeader(errResp.statusCode)
+					err := json.NewEncoder(w).Encode(map[string]string{
+						"error": errResp.message,
+					})
+					if err != nil {
+						t.Fatal("failed to encode error response:", err)
+					}
+					return
+				}
+
+				w.Header().Set("Content-Type", "application/json")
+				if err := json.NewEncoder(w).Encode(tc.response); err != nil {
+					t.Fatalf("failed to encode response: %v", err)
+				}
+			}))
+			defer ts.Close()
+
+			client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
+
+			var resp struct {
+				ID      string `json:"id"`
+				Success bool   `json:"success"`
+			}
+			err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
+
+			if tc.wantErr != "" {
+				if err == nil {
+					t.Fatalf("got nil, want error %q", tc.wantErr)
+				}
+				if err.Error() != tc.wantErr {
+					t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
+				}
+				return
+			}
+
+			if err != nil {
+				t.Fatalf("got error %q, want nil", err)
+			}
+
+			if expectedResp, ok := tc.response.(struct {
+				ID      string `json:"id"`
+				Success bool   `json:"success"`
+			}); ok {
+				if resp.ID != expectedResp.ID {
+					t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
+				}
+				if resp.Success != expectedResp.Success {
+					t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
+				}
+			}
+		})
+	}
+}

api/examples/README.md
@@ -2,9 +2,10 @@
 
 Run the examples in this directory with:
 
-```
+```shell
 go run example_name/main.go
 ```
 
 ## Chat - Chat with a model
 - [chat/main.go](chat/main.go)

94  api/types.go
@@ -10,6 +10,9 @@ import (
 	"strconv"
 	"strings"
 	"time"
+
+	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/types/model"
 )
 
 // StatusError is an error with an HTTP status code and message.
@@ -73,13 +76,13 @@ type GenerateRequest struct {
 	// this request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
-	// Images is an optional list of base64-encoded images accompanying this
+	// Images is an optional list of raw image bytes accompanying this
 	// request, for multimodal models.
 	Images []ImageData `json:"images,omitempty"`
 
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
-	Options map[string]interface{} `json:"options"`
+	Options map[string]any `json:"options"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -104,7 +107,7 @@ type ChatRequest struct {
 	Tools `json:"tools,omitempty"`
 
 	// Options lists model-specific options.
-	Options map[string]interface{} `json:"options"`
+	Options map[string]any `json:"options"`
 }
 
 type Tools []Tool
@@ -160,19 +163,65 @@ func (t *ToolCallFunctionArguments) String() string {
 
 type Tool struct {
 	Type     string       `json:"type"`
+	Items    any          `json:"items,omitempty"`
 	Function ToolFunction `json:"function"`
 }
 
+// PropertyType can be either a string or an array of strings
+type PropertyType []string
+
+// UnmarshalJSON implements the json.Unmarshaler interface
+func (pt *PropertyType) UnmarshalJSON(data []byte) error {
+	// Try to unmarshal as a string first
+	var s string
+	if err := json.Unmarshal(data, &s); err == nil {
+		*pt = []string{s}
+		return nil
+	}
+
+	// If that fails, try to unmarshal as an array of strings
+	var a []string
+	if err := json.Unmarshal(data, &a); err != nil {
+		return err
+	}
+	*pt = a
+	return nil
+}
+
+// MarshalJSON implements the json.Marshaler interface
+func (pt PropertyType) MarshalJSON() ([]byte, error) {
+	if len(pt) == 1 {
+		// If there's only one type, marshal as a string
+		return json.Marshal(pt[0])
+	}
+	// Otherwise marshal as an array
+	return json.Marshal([]string(pt))
+}
+
+// String returns a string representation of the PropertyType
+func (pt PropertyType) String() string {
+	if len(pt) == 0 {
+		return ""
+	}
+	if len(pt) == 1 {
+		return pt[0]
+	}
+	return fmt.Sprintf("%v", []string(pt))
+}
+
 type ToolFunction struct {
 	Name        string `json:"name"`
 	Description string `json:"description"`
 	Parameters  struct {
 		Type       string   `json:"type"`
+		Defs       any      `json:"$defs,omitempty"`
+		Items      any      `json:"items,omitempty"`
 		Required   []string `json:"required"`
 		Properties map[string]struct {
-			Type        string   `json:"type"`
+			Type        PropertyType `json:"type"`
+			Items       any          `json:"items,omitempty"`
 			Description string   `json:"description"`
-			Enum        []string `json:"enum,omitempty"`
+			Enum        []any    `json:"enum,omitempty"`
 		} `json:"properties"`
 	} `json:"parameters"`
 }
@@ -258,7 +307,7 @@ type EmbedRequest struct {
|
|||||||
Truncate *bool `json:"truncate,omitempty"`
|
Truncate *bool `json:"truncate,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]any `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbedResponse is the response from [Client.Embed].
|
// EmbedResponse is the response from [Client.Embed].
|
||||||
@@ -284,7 +333,7 @@ type EmbeddingRequest struct {
|
|||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]any `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbeddingResponse is the response from [Client.Embeddings].
|
// EmbeddingResponse is the response from [Client.Embeddings].
|
||||||
@@ -330,7 +379,7 @@ type ShowRequest struct {
|
|||||||
Template string `json:"template"`
|
Template string `json:"template"`
|
||||||
Verbose bool `json:"verbose"`
|
Verbose bool `json:"verbose"`
|
||||||
|
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]any `json:"options"`
|
||||||
|
|
||||||
// Deprecated: set the model name with Model instead
|
// Deprecated: set the model name with Model instead
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@@ -347,6 +396,8 @@ type ShowResponse struct {
|
|||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||||
|
Tensors []Tensor `json:"tensors,omitempty"`
|
||||||
|
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -359,9 +410,9 @@ type CopyRequest struct {
|
|||||||
// PullRequest is the request passed to [Client.Pull].
|
// PullRequest is the request passed to [Client.Pull].
|
||||||
type PullRequest struct {
|
type PullRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Insecure bool `json:"insecure,omitempty"`
|
Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored
|
||||||
Username string `json:"username"`
|
Username string `json:"username"` // Deprecated: ignored
|
||||||
Password string `json:"password"`
|
Password string `json:"password"` // Deprecated: ignored
|
||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
|
|
||||||
// Deprecated: set the model name with Model instead
|
// Deprecated: set the model name with Model instead
|
||||||
@@ -465,6 +516,13 @@ type ModelDetails struct {
|
|||||||
QuantizationLevel string `json:"quantization_level"`
|
QuantizationLevel string `json:"quantization_level"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tensor describes the metadata for a given tensor.
|
||||||
|
type Tensor struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Shape []uint64 `json:"shape"`
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Metrics) Summary() {
|
func (m *Metrics) Summary() {
|
||||||
if m.TotalDuration > 0 {
|
if m.TotalDuration > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||||
@@ -493,7 +551,7 @@ func (m *Metrics) Summary() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
func (opts *Options) FromMap(m map[string]any) error {
|
||||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||||
|
|
||||||
@@ -550,12 +608,12 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
|||||||
}
|
}
|
||||||
field.SetString(val)
|
field.SetString(val)
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
// JSON unmarshals to []interface{}, not []string
|
// JSON unmarshals to []any, not []string
|
||||||
val, ok := val.([]interface{})
|
val, ok := val.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("option %q must be of type array", key)
|
return fmt.Errorf("option %q must be of type array", key)
|
||||||
}
|
}
|
||||||
// convert []interface{} to []string
|
// convert []any to []string
|
||||||
slice := make([]string, len(val))
|
slice := make([]string, len(val))
|
||||||
for i, item := range val {
|
for i, item := range val {
|
||||||
str, ok := item.(string)
|
str, ok := item.(string)
|
||||||
@@ -609,7 +667,7 @@ func DefaultOptions() Options {
|
|||||||
|
|
||||||
Runner: Runner{
|
Runner: Runner{
|
||||||
// options set when the model is loaded
|
// options set when the model is loaded
|
||||||
NumCtx: 2048,
|
NumCtx: int(envconfig.ContextLength()),
|
||||||
NumBatch: 512,
|
NumBatch: 512,
|
||||||
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
|
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
|
||||||
NumThread: 0, // let the runtime decide
|
NumThread: 0, // let the runtime decide
|
||||||
@@ -662,7 +720,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// FormatParams converts specified parameter options to their correct types
|
// FormatParams converts specified parameter options to their correct types
|
||||||
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
func FormatParams(params map[string][]string) (map[string]any, error) {
|
||||||
opts := Options{}
|
opts := Options{}
|
||||||
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
|
||||||
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
|
||||||
@@ -676,7 +734,7 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out := make(map[string]interface{})
|
out := make(map[string]any)
|
||||||
// iterate params and set values based on json struct tags
|
// iterate params and set values based on json struct tags
|
||||||
for key, vals := range params {
|
for key, vals := range params {
|
||||||
if opt, ok := jsonOpts[key]; !ok {
|
if opt, ok := jsonOpts[key]; !ok {
|
||||||
|
|||||||
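The new PropertyType is easiest to see with a small round trip: both JSON shapes decode into the same Go slice, and a single-element value marshals back to a bare string. A minimal sketch (the package wrapper is illustrative; behavior follows the UnmarshalJSON/MarshalJSON methods above):

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// Both JSON shapes decode into the same Go type.
	var single, multi api.PropertyType
	_ = json.Unmarshal([]byte(`"string"`), &single)            // -> ["string"]
	_ = json.Unmarshal([]byte(`["string", "number"]`), &multi) // -> ["string" "number"]
	fmt.Println(single, multi)

	// One element marshals to a bare string, several to an array.
	a, _ := json.Marshal(api.PropertyType{"string"})           // `"string"`
	b, _ := json.Marshal(api.PropertyType{"string", "number"}) // `["string","number"]`
	fmt.Println(string(a), string(b))
}
```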
@@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {

 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			var oMap map[string]interface{}
+			var oMap map[string]any
 			err := json.Unmarshal([]byte(test.req), &oMap)
 			require.NoError(t, err)
 			opts := DefaultOptions()
@@ -231,3 +231,144 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
 		}
 	}
 }
+
+func TestToolFunction_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   string
+		wantErr string
+	}{
+		{
+			name: "valid enum with same types",
+			input: `{
+				"name": "test",
+				"description": "test function",
+				"parameters": {
+					"type": "object",
+					"required": ["test"],
+					"properties": {
+						"test": {
+							"type": "string",
+							"description": "test prop",
+							"enum": ["a", "b", "c"]
+						}
+					}
+				}
+			}`,
+			wantErr: "",
+		},
+		{
+			name: "empty enum array",
+			input: `{
+				"name": "test",
+				"description": "test function",
+				"parameters": {
+					"type": "object",
+					"required": ["test"],
+					"properties": {
+						"test": {
+							"type": "string",
+							"description": "test prop",
+							"enum": []
+						}
+					}
+				}
+			}`,
+			wantErr: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var tf ToolFunction
+			err := json.Unmarshal([]byte(tt.input), &tf)
+
+			if tt.wantErr != "" {
+				require.Error(t, err)
+				assert.Contains(t, err.Error(), tt.wantErr)
+			} else {
+				require.NoError(t, err)
+			}
+		})
+	}
+}
+
+func TestPropertyType_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    string
+		expected PropertyType
+	}{
+		{
+			name:     "string type",
+			input:    `"string"`,
+			expected: PropertyType{"string"},
+		},
+		{
+			name:     "array of types",
+			input:    `["string", "number"]`,
+			expected: PropertyType{"string", "number"},
+		},
+		{
+			name:     "array with single type",
+			input:    `["string"]`,
+			expected: PropertyType{"string"},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			var pt PropertyType
+			if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
+				t.Errorf("Unexpected error: %v", err)
+			}
+
+			if len(pt) != len(test.expected) {
+				t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
+			}
+
+			for i, v := range pt {
+				if v != test.expected[i] {
+					t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
+				}
+			}
+		})
+	}
+}
+
+func TestPropertyType_MarshalJSON(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    PropertyType
+		expected string
+	}{
+		{
+			name:     "single type",
+			input:    PropertyType{"string"},
+			expected: `"string"`,
+		},
+		{
+			name:     "multiple types",
+			input:    PropertyType{"string", "number"},
+			expected: `["string","number"]`,
+		},
+		{
+			name:     "empty type",
+			input:    PropertyType{},
+			expected: `[]`,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			data, err := json.Marshal(test.input)
+			if err != nil {
+				t.Errorf("Unexpected error: %v", err)
+			}
+
+			if string(data) != test.expected {
+				t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
+			}
+		})
+	}
+}
@@ -17,6 +17,6 @@ If you want to build the installer, you'll need to install
 In the top directory of this repo, run the following powershell script
 to build the ollama CLI, ollama app, and ollama installer.

-```
+```powershell
 powershell -ExecutionPolicy Bypass -File .\scripts\build_windows.ps1
 ```
178 benchmark/server_benchmark_test.go (new file)
@@ -0,0 +1,178 @@
package benchmark

import (
	"context"
	"flag"
	"fmt"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// Command line flags
var modelFlag string

func init() {
	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
	flag.Lookup("m").DefValue = "model"
}

// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
	if modelFlag == "" {
		b.Fatal("Error: -m flag is required for benchmark tests")
	}
	return modelFlag
}

type TestCase struct {
	name      string
	prompt    string
	maxTokens int
}

// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
	start := time.Now()
	var ttft time.Duration
	var metrics api.Metrics

	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
		if ttft == 0 && resp.Response != "" {
			ttft = time.Since(start)
		}
		if resp.Done {
			metrics = resp.Metrics
		}
		return nil
	})

	// Report custom metrics as part of the benchmark results
	b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
	b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")

	// Token throughput metrics
	promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
	genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
	b.ReportMetric(promptThroughput, "prompt_tok/s")
	b.ReportMetric(genThroughput, "gen_tok/s")

	// Token counts
	b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
	b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
	if err != nil {
		b.Fatal(err)
	}
}

// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				b.StopTimer()
				// Ensure model is unloaded before each iteration
				unload(client, m, b)
				b.StartTimer()

				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Pre-warm the model
			warmup(client, m, tt.prompt, b)

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		b.Fatal(err)
	}
	if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
		b.Fatalf("Model unavailable: %v", err)
	}

	return client
}

// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
	for range 3 {
		err := client.Generate(
			context.Background(),
			&api.GenerateRequest{
				Model:   model,
				Prompt:  prompt,
				Options: map[string]any{"num_predict": 50, "temperature": 0.1},
			},
			func(api.GenerateResponse) error { return nil },
		)
		if err != nil {
			b.Logf("Error during model warm-up: %v", err)
		}
	}
}

// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		b.Logf("Unload error: %v", err)
	}
	time.Sleep(1 * time.Second)
}
122 cmd/cmd.go
@@ -18,6 +18,8 @@ import (
 	"os/signal"
 	"path/filepath"
 	"runtime"
+	"slices"
+	"sort"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -34,10 +36,9 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
-	"github.com/ollama/ollama/llama"
-	"github.com/ollama/ollama/llama/runner"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
+	"github.com/ollama/ollama/runner"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -256,6 +257,7 @@ func StopHandler(cmd *cobra.Command, args []string) error {
 		if strings.Contains(err.Error(), "not found") {
 			return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
 		}
+		return err
 	}
 	return nil
 }
@@ -266,7 +268,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	opts := runOptions{
 		Model:    args[0],
 		WordWrap: os.Getenv("TERM") == "xterm-256color",
-		Options:  map[string]interface{}{},
+		Options:  map[string]any{},
 	}

 	format, err := cmd.Flags().GetString("format")
@@ -338,7 +340,21 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}

-	opts.MultiModal = len(info.ProjectorInfo) != 0
+	opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
+
+	// TODO: remove the projector info and vision info checks below,
+	// these are left in for backwards compatibility with older servers
+	// that don't have the capabilities field in the model info
+	if len(info.ProjectorInfo) != 0 {
+		opts.MultiModal = true
+	}
+	for k := range info.ModelInfo {
+		if strings.Contains(k, ".vision.") {
+			opts.MultiModal = true
+			break
+		}
+	}
+
 	opts.ParentModel = info.Details.ParentModel

 	if interactive {
@@ -559,8 +575,9 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 	parameters, errParams := cmd.Flags().GetBool("parameters")
 	system, errSystem := cmd.Flags().GetBool("system")
 	template, errTemplate := cmd.Flags().GetBool("template")
+	verbose, errVerbose := cmd.Flags().GetBool("verbose")

-	for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
+	for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
 		if boolErr != nil {
 			return errors.New("error retrieving flags")
 		}
@@ -598,7 +615,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
 	}

-	req := api.ShowRequest{Name: args[0]}
+	req := api.ShowRequest{Name: args[0], Verbose: verbose}
 	resp, err := client.Show(cmd.Context(), &req)
 	if err != nil {
 		return err
@@ -621,10 +638,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}

-	return showInfo(resp, os.Stdout)
+	return showInfo(resp, verbose, os.Stdout)
 }

-func showInfo(resp *api.ShowResponse, w io.Writer) error {
+func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 	tableRender := func(header string, rows func() [][]string) {
 		fmt.Fprintln(w, " ", header)
 		table := tablewriter.NewWriter(w)
@@ -658,6 +675,15 @@ func showInfo(resp *api.ShowResponse, w io.Writer) error {
 		return
 	})

+	if len(resp.Capabilities) > 0 {
+		tableRender("Capabilities", func() (rows [][]string) {
+			for _, capability := range resp.Capabilities {
+				rows = append(rows, []string{"", capability.String()})
+			}
+			return
+		})
+	}
+
 	if resp.ProjectorInfo != nil {
 		tableRender("Projector", func() (rows [][]string) {
 			arch := resp.ProjectorInfo["general.architecture"].(string)
@@ -681,6 +707,47 @@ func showInfo(resp *api.ShowResponse, w io.Writer) error {
 		})
 	}

+	if resp.ModelInfo != nil && verbose {
+		tableRender("Metadata", func() (rows [][]string) {
+			keys := make([]string, 0, len(resp.ModelInfo))
+			for k := range resp.ModelInfo {
+				keys = append(keys, k)
+			}
+			sort.Strings(keys)
+
+			for _, k := range keys {
+				var v string
+				switch vData := resp.ModelInfo[k].(type) {
+				case bool:
+					v = fmt.Sprintf("%t", vData)
+				case string:
+					v = vData
+				case float64:
+					v = fmt.Sprintf("%g", vData)
+				case []any:
+					n := 3
+					if len(vData) < n {
+						n = len(vData)
+					}
+					v = fmt.Sprintf("%v", vData[:n])
+				default:
+					v = fmt.Sprintf("%T", vData)
+				}
+				rows = append(rows, []string{"", k, v})
+			}
+			return
+		})
+	}
+
+	if len(resp.Tensors) > 0 && verbose {
+		tableRender("Tensors", func() (rows [][]string) {
+			for _, t := range resp.Tensors {
+				rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
+			}
+			return
+		})
+	}
+
 	head := func(s string, n int) (rows [][]string) {
 		scanner := bufio.NewScanner(strings.NewReader(s))
 		for scanner.Scan() && (len(rows) < n || n < 0) {
@@ -741,13 +808,38 @@ func PullHandler(cmd *cobra.Command, args []string) error {

 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
+			if resp.Completed == 0 {
+				// This is the initial status update for the
+				// layer, which the server sends before
+				// beginning the download, for clients to
+				// compute total size and prepare for
+				// downloads, if needed.
+				//
+				// Skipping this here to avoid showing a 0%
+				// progress bar, which *should* clue the user
+				// into the fact that many things are being
+				// downloaded and that the current active
+				// download is not that last. However, in rare
+				// cases it seems to be triggering to some, and
+				// it isn't worth explaining, so just ignore
+				// and regress to the old UI that keeps giving
+				// you the "But wait, there is more!" after
+				// each "100% done" bar, which is "better."
+				return nil
+			}
+
 			if spinner != nil {
 				spinner.Stop()
 			}

 			bar, ok := bars[resp.Digest]
 			if !ok {
-				bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
+				name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
+				name = strings.TrimSpace(name)
+				if isDigest {
+					name = name[:min(12, len(name))]
+				}
+				bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
 				bars[resp.Digest] = bar
 				p.Add(resp.Digest, bar)
 			}
@@ -767,11 +859,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 	}

 	request := api.PullRequest{Name: args[0], Insecure: insecure}
-	if err := client.Pull(cmd.Context(), &request, fn); err != nil {
-		return err
-	}
-
-	return nil
+	return client.Pull(cmd.Context(), &request, fn)
 }

 type generateContextKey string
@@ -785,7 +873,7 @@ type runOptions struct {
 	Format     string
 	System     string
 	Images     []api.ImageData
-	Options    map[string]interface{}
+	Options    map[string]any
 	MultiModal bool
 	KeepAlive  *api.Duration
 }
@@ -1187,6 +1275,7 @@ func NewCLI() *cobra.Command {
 	showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
 	showCmd.Flags().Bool("template", false, "Show template of a model")
 	showCmd.Flags().Bool("system", false, "Show system message of a model")
+	showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")

 	runCmd := &cobra.Command{
 		Use:   "run MODEL [PROMPT]",
@@ -1271,7 +1360,6 @@ func NewCLI() *cobra.Command {

 	runnerCmd := &cobra.Command{
 		Use:    "runner",
-		Short:  llama.PrintSystemInfo(),
 		Hidden: true,
 		RunE: func(cmd *cobra.Command, args []string) error {
 			return runner.Execute(os.Args[1:])
@@ -1314,12 +1402,12 @@ func NewCLI() *cobra.Command {
 				envVars["OLLAMA_NOPRUNE"],
 				envVars["OLLAMA_ORIGINS"],
 				envVars["OLLAMA_SCHED_SPREAD"],
-				envVars["OLLAMA_TMPDIR"],
 				envVars["OLLAMA_FLASH_ATTENTION"],
 				envVars["OLLAMA_KV_CACHE_TYPE"],
 				envVars["OLLAMA_LLM_LIBRARY"],
 				envVars["OLLAMA_GPU_OVERHEAD"],
 				envVars["OLLAMA_LOAD_TIMEOUT"],
+				envVars["OLLAMA_CONTEXT_LENGTH"],
 			})
 		default:
 			appendEnvDocs(cmd, envs)
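The RunHandler change above switches multimodal detection to the new capabilities field. A minimal sketch of the same check from client code is shown below; the model name is illustrative, and the backwards-compatibility fallbacks used in cmd.go for older servers are omitted.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"slices"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Show reports the capabilities the server knows for this model.
	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: "llava"})
	if err != nil {
		log.Fatal(err)
	}

	if slices.Contains(resp.Capabilities, model.CapabilityVision) {
		fmt.Println("model accepts images")
	}
}
```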
316 cmd/cmd_test.go
@@ -10,11 +10,13 @@ import (
 	"os"
 	"strings"
 	"testing"
+	"time"

 	"github.com/google/go-cmp/cmp"
 	"github.com/spf13/cobra"

 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
 )

 func TestShowInfo(t *testing.T) {
@@ -26,7 +28,7 @@ func TestShowInfo(t *testing.T) {
 			ParameterSize:     "7B",
 			QuantizationLevel: "FP16",
 		},
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}

@@ -56,7 +58,7 @@ func TestShowInfo(t *testing.T) {
 			ParameterSize:     "7B",
 			QuantizationLevel: "FP16",
 		},
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}

@@ -67,6 +69,60 @@ func TestShowInfo(t *testing.T) {
     embedding length    0
     quantization        FP16

+`
+		if diff := cmp.Diff(expect, b.String()); diff != "" {
+			t.Errorf("unexpected output (-want +got):\n%s", diff)
+		}
+	})
+
+	t.Run("verbose model", func(t *testing.T) {
+		var b bytes.Buffer
+		if err := showInfo(&api.ShowResponse{
+			Details: api.ModelDetails{
+				Family:            "test",
+				ParameterSize:     "8B",
+				QuantizationLevel: "FP16",
+			},
+			Parameters: `
+			stop up`,
+			ModelInfo: map[string]any{
+				"general.architecture":    "test",
+				"general.parameter_count": float64(8_000_000_000),
+				"some.true_bool":          true,
+				"some.false_bool":         false,
+				"test.context_length":     float64(1000),
+				"test.embedding_length":   float64(11434),
+			},
+			Tensors: []api.Tensor{
+				{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
+				{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
+			},
+		}, true, &b); err != nil {
+			t.Fatal(err)
+		}
+
+		expect := `  Model
+    architecture        test
+    parameters          8B
+    context length      1000
+    embedding length    11434
+    quantization        FP16
+
+  Parameters
+    stop    up
+
+  Metadata
+    general.architecture       test
+    general.parameter_count    8e+09
+    some.false_bool            false
+    some.true_bool             true
+    test.context_length        1000
+    test.embedding_length      11434
+
+  Tensors
+    blk.0.attn_k.weight    BF16    [42 3117]
+    blk.0.attn_q.weight    FP16    [3117 42]
+
 `
 		if diff := cmp.Diff(expect, b.String()); diff != "" {
 			t.Errorf("unexpected output (-want +got):\n%s", diff)
@@ -88,7 +144,7 @@ func TestShowInfo(t *testing.T) {
 		stop you
 		stop up
 		temperature 99`,
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}

@@ -125,7 +181,7 @@ func TestShowInfo(t *testing.T) {
 			"clip.vision.embedding_length": float64(0),
 			"clip.vision.projection_dim":   float64(0),
 		},
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}

@@ -158,7 +214,7 @@ func TestShowInfo(t *testing.T) {
 Ahoy, matey!
 Weigh anchor!
 `,
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}

@@ -187,7 +243,7 @@ Weigh anchor!
 			QuantizationLevel: "FP16",
 		},
 		License: license,
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}

@@ -205,6 +261,34 @@ Weigh anchor!
 			t.Errorf("unexpected output (-want +got):\n%s", diff)
 		}
 	})
+
+	t.Run("capabilities", func(t *testing.T) {
+		var b bytes.Buffer
+		if err := showInfo(&api.ShowResponse{
+			Details: api.ModelDetails{
+				Family:            "test",
+				ParameterSize:     "7B",
+				QuantizationLevel: "FP16",
+			},
+			Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
+		}, false, &b); err != nil {
+			t.Fatal(err)
+		}
+
+		expect := "  Model\n" +
+			"    architecture    test    \n" +
+			"    parameters      7B      \n" +
+			"    quantization    FP16    \n" +
+			"\n" +
+			"  Capabilities\n" +
+			"    vision    \n" +
+			"    tools     \n" +
+			"\n"
+
+		if diff := cmp.Diff(expect, b.String()); diff != "" {
+			t.Errorf("unexpected output (-want +got):\n%s", diff)
+		}
+	})
 }

 func TestDeleteHandler(t *testing.T) {
@@ -331,6 +415,7 @@ func TestGetModelfileName(t *testing.T) {
 	if err != nil {
 		t.Fatalf("temp modelfile creation failed: %v", err)
 	}
+	defer tempFile.Close()

 	expectedFilename = tempFile.Name()
 	err = cmd.Flags().Set("file", expectedFilename)
@@ -490,6 +575,96 @@ func TestPushHandler(t *testing.T) {
 	}
 }

+func TestListHandler(t *testing.T) {
+	tests := []struct {
+		name           string
+		args           []string
+		serverResponse []api.ListModelResponse
+		expectedError  string
+		expectedOutput string
+	}{
+		{
+			name: "list all models",
+			args: []string{},
+			serverResponse: []api.ListModelResponse{
+				{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
+				{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)},
+			},
+			expectedOutput: "NAME      ID              SIZE      MODIFIED     \n" +
+				"model1    sha256:abc12    1.0 KB    24 hours ago    \n" +
+				"model2    sha256:def45    2.0 KB    2 days ago      \n",
+		},
+		{
+			name: "filter models by prefix",
+			args: []string{"model1"},
+			serverResponse: []api.ListModelResponse{
+				{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
+				{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)},
+			},
+			expectedOutput: "NAME      ID              SIZE      MODIFIED     \n" +
+				"model1    sha256:abc12    1.0 KB    24 hours ago    \n",
+		},
+		{
+			name:          "server error",
+			args:          []string{},
+			expectedError: "server error",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				if r.URL.Path != "/api/tags" || r.Method != http.MethodGet {
+					t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
+					http.Error(w, "not found", http.StatusNotFound)
+					return
+				}
+
+				if tt.expectedError != "" {
+					http.Error(w, tt.expectedError, http.StatusInternalServerError)
+					return
+				}
+
+				response := api.ListResponse{Models: tt.serverResponse}
+				if err := json.NewEncoder(w).Encode(response); err != nil {
+					t.Fatal(err)
+				}
+			}))
+			defer mockServer.Close()
+
+			t.Setenv("OLLAMA_HOST", mockServer.URL)
+
+			cmd := &cobra.Command{}
+			cmd.SetContext(context.TODO())
+
+			// Capture stdout
+			oldStdout := os.Stdout
+			r, w, _ := os.Pipe()
+			os.Stdout = w
+
+			err := ListHandler(cmd, tt.args)
+
+			// Restore stdout and get output
+			w.Close()
+			os.Stdout = oldStdout
+			output, _ := io.ReadAll(r)
+
+			if tt.expectedError == "" {
+				if err != nil {
+					t.Errorf("expected no error, got %v", err)
+				}
+				if got := string(output); got != tt.expectedOutput {
+					t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
+				}
+			} else {
+				if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
+					t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
+				}
+			}
+		})
+	}
+}
+
 func TestCreateHandler(t *testing.T) {
 	tests := []struct {
 		name string
@@ -616,3 +791,132 @@ func TestCreateHandler(t *testing.T) {
 		})
 	}
 }
+
+func TestNewCreateRequest(t *testing.T) {
+	tests := []struct {
+		name     string
+		from     string
+		opts     runOptions
+		expected *api.CreateRequest
+	}{
+		{
+			"basic test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "",
+				Prompt:      "You are a fun AI agent",
+				Messages:    []api.Message{},
+				WordWrap:    true,
+			},
+			&api.CreateRequest{
+				From:  "mymodel",
+				Model: "newmodel",
+			},
+		},
+		{
+			"parent model test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "parentmodel",
+				Messages:    []api.Message{},
+				WordWrap:    true,
+			},
+			&api.CreateRequest{
+				From:  "parentmodel",
+				Model: "newmodel",
+			},
+		},
+		{
+			"parent model as filepath test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "/some/file/like/etc/passwd",
+				Messages:    []api.Message{},
+				WordWrap:    true,
+			},
+			&api.CreateRequest{
+				From:  "mymodel",
+				Model: "newmodel",
+			},
+		},
+		{
+			"parent model as windows filepath test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "D:\\some\\file\\like\\etc\\passwd",
+				Messages:    []api.Message{},
+				WordWrap:    true,
+			},
+			&api.CreateRequest{
+				From:  "mymodel",
+				Model: "newmodel",
+			},
+		},
+		{
+			"options test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "parentmodel",
+				Options: map[string]any{
+					"temperature": 1.0,
+				},
+			},
+			&api.CreateRequest{
+				From:  "parentmodel",
+				Model: "newmodel",
+				Parameters: map[string]any{
+					"temperature": 1.0,
+				},
+			},
+		},
+		{
+			"messages test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "parentmodel",
+				System:      "You are a fun AI agent",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "hello there!",
+					},
+					{
+						Role:    "assistant",
+						Content: "hello to you!",
+					},
+				},
+				WordWrap: true,
+			},
+			&api.CreateRequest{
+				From:   "parentmodel",
+				Model:  "newmodel",
+				System: "You are a fun AI agent",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "hello there!",
+					},
+					{
+						Role:    "assistant",
+						Content: "hello to you!",
+					},
+				},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			actual := NewCreateRequest(tt.from, tt.opts)
+			if !cmp.Equal(actual, tt.expected) {
+				t.Errorf("expected output %#v, got %#v", tt.expected, actual)
+			}
+		})
+	}
+}
@@ -18,6 +18,7 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/types/errtypes"
+	"github.com/ollama/ollama/types/model"
 )

 type MultilineState int
@@ -195,6 +196,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			opts.Messages = []api.Message{}
 			fmt.Printf("Loading model '%s'\n", opts.Model)
 			if err := loadOrUnloadModel(cmd, &opts); err != nil {
+				if strings.Contains(err.Error(), "not found") {
+					fmt.Printf("error: %v\n", err)
+					continue
+				}
 				return err
 			}
 			continue
@@ -343,7 +348,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {

 			switch args[1] {
 			case "info":
-				_ = showInfo(resp, os.Stderr)
+				_ = showInfo(resp, false, os.Stderr)
 			case "license":
 				if resp.License == "" {
 					fmt.Println("No license was specified for this model.")
@@ -455,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 }

 func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
+	parentModel := opts.ParentModel
+
+	modelName := model.ParseName(parentModel)
+	if !modelName.IsValid() {
+		parentModel = ""
+	}
+
 	req := &api.CreateRequest{
-		Name: name,
-		From: cmp.Or(opts.ParentModel, opts.Model),
+		Model: name,
+		From:  cmp.Or(parentModel, opts.Model),
 	}

 	if opts.System != "" {
@@ -491,6 +503,7 @@ func normalizeFilePath(fp string) string {
 		"\\\\", "\\", // Escaped backslash
 		"\\*", "*", // Escaped asterisk
 		"\\?", "?", // Escaped question mark
+		"\\~", "~", // Escaped tilde
 	).Replace(fp)
 }

@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"os"

-	"github.com/ollama/ollama/llama/runner"
+	"github.com/ollama/ollama/runner"
 )

 func main() {
@@ -7,14 +7,20 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ModelParameters struct {
|
type ModelParameters struct {
|
||||||
Architectures []string `json:"architectures"`
|
Architectures []string `json:"architectures"`
|
||||||
VocabSize uint32 `json:"vocab_size"`
|
VocabSize uint32 `json:"vocab_size"`
|
||||||
|
TextModel TextParameters `json:"text_config"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TextParameters struct {
|
||||||
|
VocabSize uint32 `json:"vocab_size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AdapterParameters struct {
|
type AdapterParameters struct {
|
||||||
@@ -27,8 +33,8 @@ type AdapterParameters struct {
|
|||||||
} `json:"lora_parameters"`
|
} `json:"lora_parameters"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ModelParameters) KV(t *Tokenizer) llm.KV {
|
func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
||||||
kv := llm.KV{
|
kv := ggml.KV{
|
||||||
"general.file_type": uint32(1),
|
"general.file_type": uint32(1),
|
||||||
"general.quantization_version": uint32(2),
|
"general.quantization_version": uint32(2),
|
||||||
"tokenizer.ggml.pre": t.Pre,
|
"tokenizer.ggml.pre": t.Pre,
|
||||||
@@ -54,7 +60,7 @@ func (ModelParameters) KV(t *Tokenizer) llm.KV {
|
|||||||
return kv
|
return kv
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p AdapterParameters) KV() llm.KV {
|
func (p AdapterParameters) KV() ggml.KV {
|
||||||
var alpha float32
|
var alpha float32
|
||||||
if p.LoraParameters.Alpha == 0 {
|
if p.LoraParameters.Alpha == 0 {
|
||||||
alpha = float32(p.Alpha)
|
alpha = float32(p.Alpha)
|
||||||
@@ -62,7 +68,7 @@ func (p AdapterParameters) KV() llm.KV {
|
|||||||
alpha = p.LoraParameters.Alpha
|
alpha = p.LoraParameters.Alpha
|
||||||
}
|
}
|
||||||
|
|
||||||
kv := llm.KV{
|
kv := ggml.KV{
|
||||||
"adapter.lora.alpha": alpha,
|
"adapter.lora.alpha": alpha,
|
||||||
"adapter.type": "lora",
|
"adapter.type": "lora",
|
||||||
"general.file_type": uint32(1),
|
"general.file_type": uint32(1),
|
||||||
@@ -79,27 +85,17 @@ func (ModelParameters) specialTokenTypes() []string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
|
|
||||||
return llm.WriteGGUF(ws, kv, ts)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
|
|
||||||
return llm.WriteGGUF(ws, kv, ts)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ModelConverter interface {
|
type ModelConverter interface {
|
||||||
// KV maps parameters to LLM key-values
|
// KV maps parameters to LLM key-values
|
||||||
KV(*Tokenizer) llm.KV
|
KV(*Tokenizer) ggml.KV
|
||||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||||
Tensors([]Tensor) []llm.Tensor
|
Tensors([]Tensor) []ggml.Tensor
|
||||||
// Replacements returns a list of string pairs to replace in tensor names.
|
// Replacements returns a list of string pairs to replace in tensor names.
|
||||||
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
||||||
Replacements() []string
|
Replacements() []string
|
||||||
|
|
||||||
// specialTokenTypes returns any special token types the model uses
|
// specialTokenTypes returns any special token types the model uses
|
||||||
specialTokenTypes() []string
|
specialTokenTypes() []string
|
||||||
// writeFile writes the model to the provided io.WriteSeeker
|
|
||||||
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type moreParser interface {
|
type moreParser interface {
|
||||||
@@ -108,17 +104,15 @@ type moreParser interface {
|
|||||||
|
|
||||||
type AdapterConverter interface {
|
type AdapterConverter interface {
|
||||||
// KV maps parameters to LLM key-values
|
// KV maps parameters to LLM key-values
|
||||||
KV(llm.KV) llm.KV
|
KV(ggml.KV) ggml.KV
|
||||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||||
Tensors([]Tensor) []llm.Tensor
|
Tensors([]Tensor) []ggml.Tensor
|
||||||
// Replacements returns a list of string pairs to replace in tensor names.
|
// Replacements returns a list of string pairs to replace in tensor names.
|
||||||
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
||||||
Replacements() []string
|
Replacements() []string
|
||||||
|
|
||||||
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
|
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
|
||||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -153,7 +147,7 @@ func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
|
return writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||||
@@ -177,14 +171,20 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
|
|
||||||
var conv ModelConverter
|
var conv ModelConverter
|
||||||
switch p.Architectures[0] {
|
switch p.Architectures[0] {
|
||||||
case "LlamaForCausalLM", "MistralForCausalLM":
|
case "LlamaForCausalLM":
|
||||||
conv = &llamaModel{}
|
conv = &llamaModel{}
|
||||||
|
case "Llama4ForConditionalGeneration":
|
||||||
|
conv = &llama4Model{}
|
||||||
|
case "Mistral3ForConditionalGeneration":
|
||||||
|
conv = &mistral3Model{}
|
||||||
case "MixtralForCausalLM":
|
case "MixtralForCausalLM":
|
||||||
conv = &mixtralModel{}
|
conv = &mixtralModel{}
|
||||||
case "GemmaForCausalLM":
|
case "GemmaForCausalLM":
|
||||||
conv = &gemmaModel{}
|
conv = &gemmaModel{}
|
||||||
case "Gemma2ForCausalLM":
|
case "Gemma2ForCausalLM":
|
||||||
conv = &gemma2Model{}
|
conv = &gemma2Model{}
|
||||||
|
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
|
||||||
|
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||||
case "Phi3ForCausalLM":
|
case "Phi3ForCausalLM":
|
||||||
conv = &phi3Model{}
|
conv = &phi3Model{}
|
||||||
case "Qwen2ForCausalLM":
|
case "Qwen2ForCausalLM":
|
||||||
@@ -194,7 +194,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
case "CohereForCausalLM":
|
case "CohereForCausalLM":
|
||||||
conv = &commandrModel{}
|
conv = &commandrModel{}
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported architecture")
|
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(bts, conv); err != nil {
|
if err := json.Unmarshal(bts, conv); err != nil {
|
||||||
@@ -213,7 +213,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vocabSize := int(p.VocabSize)
|
vocabSize := int(p.VocabSize)
|
||||||
|
if vocabSize == 0 {
|
||||||
|
tVocabSize := int(p.TextModel.VocabSize)
|
||||||
|
vocabSize = tVocabSize
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
|
case vocabSize == 0:
|
||||||
|
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
|
||||||
case vocabSize > len(t.Vocabulary.Tokens):
|
case vocabSize > len(t.Vocabulary.Tokens):
|
||||||
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
||||||
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
||||||
@@ -232,5 +239,13 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
|
return writeFile(ws, conv.KV(t), conv.Tensors(ts))
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
|
||||||
|
for i := range ts {
|
||||||
|
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||||
|
slices.Reverse(ts[i].Shape)
|
||||||
|
}
|
||||||
|
return ggml.WriteGGUF(ws, kv, ts)
|
||||||
}
|
}
|
||||||
|
|||||||
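The new writeFile helper above reverses every tensor's shape before handing it to ggml.WriteGGUF, presumably because GGUF lists dimensions fastest-varying first while the converter tracks them in the row-major order of the source checkpoint. A minimal standalone sketch of that shape handling, not part of this diff, with a hypothetical tensor shape:

package main

import (
	"fmt"
	"slices"
)

func main() {
	// hypothetical shape as parsed from a safetensors header: [rows, cols]
	shape := []uint64{4096, 11008}

	// clone first so the caller's slice is left untouched, then reverse in place,
	// mirroring what the helper does for each tensor before writing
	out := slices.Clone(shape)
	slices.Reverse(out)

	fmt.Println(out) // [11008 4096]
}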
@@ -8,7 +8,7 @@ import (
 	"slices"
 	"strings"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type bertModel struct {

@@ -85,7 +85,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
 	return nil
 }

-func (p *bertModel) KV(t *Tokenizer) llm.KV {
+func (p *bertModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "bert"
 	kv["bert.attention.causal"] = false

@@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
 		if slices.Contains([]string{
 			"embeddings.position_ids",

@@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
 			continue
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
@@ -3,7 +3,7 @@ package convert
 import (
 	"cmp"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type commandrModel struct {

@@ -24,7 +24,7 @@ type commandrModel struct {

 var _ ModelConverter = (*commandrModel)(nil)

-func (p *commandrModel) KV(t *Tokenizer) llm.KV {
+func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "command-r"
 	kv["general.name"] = "command-r"

@@ -43,10 +43,10 @@ func (p *commandrModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *commandrModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *commandrModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
@@ -6,7 +6,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type gemmaModel struct {

@@ -23,7 +23,7 @@ type gemmaModel struct {

 var _ ModelConverter = (*gemmaModel)(nil)

-func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
+func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "gemma"
 	kv["gemma.context_length"] = p.MaxPositionEmbeddings

@@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *gemmaModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
-		if strings.HasSuffix(t.Name(), "_norm.weight") {
+		if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
 			t.SetRepacker(p.addOne)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
@@ -1,8 +1,6 @@
 package convert

-import (
-	"github.com/ollama/ollama/llm"
-)
+import "github.com/ollama/ollama/fs/ggml"

 type gemma2Model struct {
 	gemmaModel

@@ -11,7 +9,7 @@ type gemma2Model struct {
 	FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
 }

-func (p *gemma2Model) KV(t *Tokenizer) llm.KV {
+func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "gemma2"
 	kv["gemma2.context_length"] = p.MaxPositionEmbeddings
@@ -6,7 +6,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type gemma2Adapter struct {

@@ -15,14 +15,14 @@ type gemma2Adapter struct {

 var _ AdapterConverter = (*gemma2Adapter)(nil)

-func (p *gemma2Adapter) KV(baseKV llm.KV) llm.KV {
+func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
 	kv := p.AdapterParameters.KV()
 	kv["general.architecture"] = "gemma2"
 	return kv
 }

-func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
 		shape := t.Shape()
 		if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||

@@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
 			t.SetRepacker(p.repack)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_gemma3.go (new file, 142 lines)
@@ -0,0 +1,142 @@
package convert

import (
	"cmp"

	"github.com/ollama/ollama/fs/ggml"
)

type gemma3Model struct {
	gemmaModel
	Architecture string
	TextModel    struct {
		HeadDim          uint32 `json:"head_dim"`
		HiddenSize       uint32 `json:"hidden_size"`
		HiddenLayers     uint32 `json:"num_hidden_layers"`
		IntermediateSize uint32 `json:"intermediate_size"`
		SlidingWindow    uint32 `json:"sliding_window"`
	} `json:"text_config"`
	VisionModel struct {
		NumAttentionHeads uint32  `json:"num_attention_heads"` // attention.head_count 16
		LayerNormEpsilon  float32 `json:"layer_norm_eps"`      // attention.layer_norm_epsilon 1e-05
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`   // block_count 32
		HiddenSize        uint32  `json:"hidden_size"`         // embedding_length 1280
		IntermediateSize  uint32  `json:"intermediate_size"`   // feed_forward_length 5120
		ImageSize         uint32  `json:"image_size"`          // image_size 560
		NumChannels       uint32  `json:"num_channels"`        // num_channels 3
		PatchSize         uint32  `json:"patch_size"`          // patch_size 14
	} `json:"vision_config"`
	MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
	NumAttentionHeads        uint32  `json:"num_attention_heads"`
	NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
	RMSNormEPS               float32 `json:"rms_norm_eps"`
	HeadDim                  uint32  `json:"head_dim"`
	FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
	RopeLocalTheta           float32 `json:"rope_local_base_freq"`
	RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
	SlidingWindow            uint32  `json:"sliding_window"`
	MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
}

const (
	gemma4BLayerCount  = 34
	gemma12BLayerCount = 48
	gemma27BLayerCount = 62
)

func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "gemma3"

	numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
	kv["gemma3.block_count"] = numBlocks

	var (
		numHeads   uint32
		numKVHeads uint32
	)

	switch numBlocks {
	case gemma4BLayerCount:
		numHeads = 8
		numKVHeads = 4
	case gemma12BLayerCount:
		numHeads = 16
		numKVHeads = 8
	case gemma27BLayerCount:
		numHeads = 32
		numKVHeads = 16
	default:
		numHeads = p.NumAttentionHeads
		numKVHeads = p.NumKeyValueHeads
	}

	kv["gemma3.attention.head_count"] = numHeads
	kv["gemma3.attention.head_count_kv"] = numKVHeads

	switch p.Architecture {
	case "Gemma3ForCausalLM":
		kv["gemma3.context_length"] = p.MaxPositionEmbeddings
		kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
		kv["gemma3.attention.key_length"] = p.HeadDim
		kv["gemma3.attention.value_length"] = p.HeadDim
		kv["gemma3.attention.sliding_window"] = p.SlidingWindow
		kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
		kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
		kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
		kv["gemma3.embedding_length"] = p.HiddenSize
		kv["gemma3.feed_forward_length"] = p.IntermediateSize
	default:
		kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
		kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
		kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
		kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
		kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
		kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
		kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
		kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
		kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
		kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
		kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
		kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
		kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
		kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
	}

	if p.MultiModalTokensPerImage > 0 {
		kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
	}

	return kv
}

func (p *gemma3Model) Replacements() []string {
	return []string{
		"lm_head", "output",
		"model.embed_tokens", "token_embd",
		"model.norm", "output_norm",
		"vision_tower.vision_model.embeddings", "v",
		"vision_tower.vision_model", "v",
		"vision_model.vision_model.embeddings", "v",
		"vision_model.vision_model", "v",
		"language_model.", "",
		"model.layers", "blk",
		"encoder.layers", "blk",
		"input_layernorm", "attn_norm",
		"self_attn.q_proj", "attn_q",
		"self_attn.q_norm", "attn_q_norm",
		"self_attn.k_proj", "attn_k",
		"self_attn.k_norm", "attn_k_norm",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"self_attn.out_proj", "attn_output",
		"mlp.gate_proj", "ffn_gate",
		"mlp.down_proj", "ffn_down",
		"mlp.up_proj", "ffn_up",
		"post_attention_layernorm", "post_attention_norm",
		"pre_feedforward_layernorm", "ffn_norm",
		"post_feedforward_layernorm", "post_ffw_norm",
		"input_projection_weight", "input_projection.weight",
		"multi_modal_projector", "mm",
	}
}
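gemma3Model.KV above leans on cmp.Or to substitute a default whenever the HF config leaves a field at its zero value. A small standalone sketch of that fallback pattern, not part of this diff, with hypothetical values:

package main

import (
	"cmp"
	"fmt"
)

func main() {
	var contextLength uint32 // zero: the config omitted max_position_embeddings
	fmt.Println(cmp.Or(contextLength, uint32(131072))) // 131072, the fallback wins

	contextLength = 8192
	fmt.Println(cmp.Or(contextLength, uint32(131072))) // 8192, the configured value wins
}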
@@ -9,7 +9,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type llamaModel struct {

@@ -33,7 +33,7 @@ type llamaModel struct {
 		Factor              float32 `json:"factor"`
 		LowFrequencyFactor  float32 `json:"low_freq_factor"`
 		HighFrequencyFactor float32 `json:"high_freq_factor"`
-		OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
+		OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`

 		factors ropeFactor
 	} `json:"rope_scaling"`

@@ -42,11 +42,13 @@ type llamaModel struct {
 	LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
 	NormEpsilon      float32 `json:"norm_epsilon"`
 	HeadDim          uint32  `json:"head_dim"`
+
+	skipRepack bool
 }

 var _ ModelConverter = (*llamaModel)(nil)

-func (p *llamaModel) KV(t *Tokenizer) llm.KV {
+func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "llama"
 	kv["llama.vocab_size"] = p.VocabSize

@@ -70,6 +72,10 @@ func (p *llamaModel) KV(t *Tokenizer) llm.KV {
 		kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
 	}

+	if p.HeadDim > 0 {
+		kv["llama.attention.head_dim"] = p.HeadDim
+	}
+
 	if p.RopeTheta > 0 {
 		kv["llama.rope.freq_base"] = p.RopeTheta
 	}

@@ -84,7 +90,7 @@ func (p *llamaModel) KV(t *Tokenizer) llm.KV {
 		factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
 		factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)

-		original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
+		original := cmp.Or(p.RopeScaling.OriginalMaxPositionEmbeddings, 8192)
 		lambdaLow := float32(original) / factorLow
 		lambdaHigh := float32(original) / factorHigh

@@ -120,11 +126,11 @@ func (p *llamaModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor

 	if p.RopeScaling.factors != nil {
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  "rope_freqs.weight",
 			Kind:  0,
 			Shape: []uint64{uint64(len(p.RopeScaling.factors))},

@@ -133,12 +139,13 @@ func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
 	}

 	for _, t := range ts {
-		if strings.HasSuffix(t.Name(), "attn_q.weight") ||
-			strings.HasSuffix(t.Name(), "attn_k.weight") {
+		if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
+			if !p.skipRepack {
 				t.SetRepacker(p.repack)
+			}
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_llama4.go (new file, 169 lines)
@@ -0,0 +1,169 @@
package convert

import (
	"slices"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

type llama4Model struct {
	ModelParameters
	TextModel struct {
		llamaModel
		NumExpertsPerToken     uint32 `json:"num_experts_per_tok"`
		NumLocalExperts        uint32 `json:"num_local_experts"`
		InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
		UseQKNorm              bool   `json:"use_qk_norm"`
		IntermediateSizeMLP    uint32 `json:"intermediate_size_mlp"`
		AttentionChunkSize     uint32 `json:"attention_chunk_size"`
	} `json:"text_config"`
	VisionModel struct {
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		ImageSize         uint32  `json:"image_size"`
		PatchSize         uint32  `json:"patch_size"`
		RopeTheta         float32 `json:"rope_theta"`
		NormEpsilon       float32 `json:"norm_eps"`
		PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
	} `json:"vision_config"`
}

// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "llama4"

	for k, v := range p.TextModel.KV(t) {
		if strings.HasPrefix(k, "llama.") {
			kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
		}
	}

	kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
	kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize

	kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
	kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
	kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
	kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
	kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize

	kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
	kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
	kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
	kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
	kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
	kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
	kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
	kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
	kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
	return kv
}

// Replacements implements ModelConverter.
func (p *llama4Model) Replacements() []string {
	return append(
		p.TextModel.Replacements(),
		"language_model.", "",
		"vision_model", "v",
		"multi_modal_projector", "mm",
		"feed_forward.down_proj", "ffn_down",
		"feed_forward.up_proj", "ffn_up",
		"feed_forward.gate_proj", "ffn_gate",
		"feed_forward.", "ffn_",
		"shared_expert.down_proj", "down_shexp",
		"shared_expert.gate_proj", "gate_shexp",
		"shared_expert.up_proj", "up_shexp",
		"experts.down_proj", "down_exps.weight",
		"experts.gate_up_proj", "gate_up_exps.weight",
		"router", "gate_inp",
		"patch_embedding.linear", "patch_embedding",
	)
}

// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
	var out []ggml.Tensor

	var textTensors []Tensor
	for _, t := range ts {
		if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    t.Shape(),
				WriterTo: t,
			})
		} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
			// gate and up projectors are fused
			// dims[1], dims[2] must be swapped
			// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
			halfDim := int(t.Shape()[2]) / 2

			newShape := slices.Clone(t.Shape())
			newShape[1], newShape[2] = newShape[2]/2, newShape[1]
			for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
				// clone tensor since we need separate repackers
				tt := t.Clone()
				tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
				out = append(out, ggml.Tensor{
					Name:     strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
					Kind:     tt.Kind(),
					Shape:    newShape,
					WriterTo: tt,
				})
			}
		} else if strings.Contains(t.Name(), "ffn_down_exps") {
			// dims[1], dims[2] must be swapped
			// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
			t.SetRepacker(p.repack())
			newShape := slices.Clone(t.Shape())
			newShape[1], newShape[2] = newShape[2], newShape[1]
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    newShape,
				WriterTo: t,
			})
		} else {
			textTensors = append(textTensors, t)
		}
	}

	p.TextModel.skipRepack = true
	out = append(out, p.TextModel.Tensors(textTensors)...)
	return out
}

func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
	return func(name string, data []float32, shape []uint64) ([]float32, error) {
		dims := make([]int, len(shape))
		for i, dim := range shape {
			dims[i] = int(dim)
		}

		var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
		t, err := t.Slice(slice...)
		if err != nil {
			return nil, err
		}

		if err := t.T(0, 2, 1); err != nil {
			return nil, err
		}

		t = tensor.Materialize(t)
		// flatten tensor so it can be return as a vector
		if err := t.Reshape(t.Shape().TotalSize()); err != nil {
			return nil, err
		}

		return native.VectorF32(t.(*tensor.Dense))
	}
}
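In the llama4 converter above, experts.gate_up_proj arrives fused: the last dimension holds the gate half followed by the up half, and each half is emitted as a separate tensor with its own repacker. A simplified plain-Go sketch of splitting such a fused buffer, not part of this diff; it ignores the dims[1]/dims[2] swap the real repacker also performs, and the sizes are hypothetical:

package main

import "fmt"

// splitFused walks a flat buffer row by row and separates the first half of
// each row (gate) from the second half (up).
func splitFused(data []float32, rowLen int) (gate, up []float32) {
	half := rowLen / 2
	for i := 0; i < len(data); i += rowLen {
		gate = append(gate, data[i:i+half]...)
		up = append(up, data[i+half:i+rowLen]...)
	}
	return gate, up
}

func main() {
	// two rows of length 4, laid out [g g u u] per row
	fused := []float32{1, 2, 3, 4, 5, 6, 7, 8}
	gate, up := splitFused(fused, 4)
	fmt.Println(gate) // [1 2 5 6]
	fmt.Println(up)   // [3 4 7 8]
}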
@@ -7,7 +7,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type llamaAdapter struct {

@@ -18,7 +18,7 @@ type llamaAdapter struct {

 var _ AdapterConverter = (*llamaAdapter)(nil)

-func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
+func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
 	kv := p.AdapterParameters.KV()
 	kv["general.architecture"] = "llama"
 	kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]

@@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
 	return kv
 }

-func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
 		shape := t.Shape()
 		if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||

@@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
 			t.SetRepacker(p.repack)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: shape,
convert/convert_mistral.go (new file, 190 lines)
@@ -0,0 +1,190 @@
package convert

import (
	"cmp"
	"fmt"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

type mistral3Model struct {
	ModelParameters
	ImageTokenIndex    uint32 `json:"image_token_index"`
	SpatialMergeSize   uint32 `json:"spatial_merge_size"`
	VisionFeatureLayer int32  `json:"vision_feature_layer"`
	TextModel          struct {
		NumHiddenLayers       uint32  `json:"num_hidden_layers"`
		MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
		HiddenSize            uint32  `json:"hidden_size"`
		IntermediateSize      uint32  `json:"intermediate_size"`
		NumAttentionHeads     uint32  `json:"num_attention_heads"`
		NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
		RopeTheta             float32 `json:"rope_theta"`
		RMSNormEPS            float32 `json:"rms_norm_eps"`
		HeadDim               uint32  `json:"head_dim"`
		SlidingWindow         *uint32 `json:"sliding_window"`
		HiddenAct             string  `json:"hidden_act"`
		VocabSize             uint32  `json:"vocab_size"`
	} `json:"text_config"`
	VisionModel struct {
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		ImageSize         uint32  `json:"image_size"`
		NumChannels       uint32  `json:"num_channels"`
		PatchSize         uint32  `json:"patch_size"`
		HeadDim           uint32  `json:"head_dim"`
		HiddenAct         string  `json:"hidden_act"`
		RopeTheta         float32 `json:"rope_theta"`
	} `json:"vision_config"`
	MultiModalProjectorBias bool   `json:"multimodal_projector_bias"`
	ProjectorHiddenAct      string `json:"projector_hidden_act"`
}

func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "mistral3"
	kv["mistral3.vocab_size"] = p.TextModel.VocabSize

	// Text configuration
	kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
	kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
	kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
	kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
	kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
	kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
	kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
	kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
	kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
	kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta

	// Vision configuration
	kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
	kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
	kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
	kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
	kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
	kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
	kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
	kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
	// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
	kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta

	// Multimodal configuration
	kv["mistral3.image_token_index"] = p.ImageTokenIndex
	kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize

	kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias

	if p.ProjectorHiddenAct != "" {
		kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
	}

	return kv
}

func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
	var out []ggml.Tensor

	for _, t := range ts {
		if !strings.HasPrefix(t.Name(), "v.") {
			if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
				strings.HasSuffix(t.Name(), ".attn_k.weight") {
				t.SetRepacker(p.repack)
			}
		}

		out = append(out, ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}

	return out
}

func (p *mistral3Model) Replacements() []string {
	return []string{
		"language_model.model.norm", "output_norm",
		"language_model.model.", "",
		"language_model.", "",
		"layers", "blk",
		"transformer.layers", "blk",
		"vision_tower", "v",
		"ln_pre", "encoder_norm",
		"input_layernorm", "attn_norm",
		"post_attention_layernorm", "ffn_norm",
		"embed_tokens", "token_embd",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"mlp.down_proj", "ffn_down",
		"mlp.gate_proj", "ffn_gate",
		"mlp.up_proj", "ffn_up",
		"attention.q_proj", "attn_q",
		"attention.k_proj", "attn_k",
		"attention.v_proj", "attn_v",
		"attention.o_proj", "attn_output",
		"attention_norm", "attn_norm",
		"feed_forward.gate_proj", "ffn_gate",
		"feed_forward.down_proj", "ffn_down",
		"feed_forward.up_proj", "ffn_up",
		"multi_modal_projector", "mm",
		"ffn_norm", "ffn_norm",
		"lm_head", "output",
	}
}

func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
	var dims []int
	for _, dim := range shape {
		dims = append(dims, int(dim))
	}

	var heads uint32
	if strings.HasSuffix(name, ".attn_q.weight") {
		heads = p.TextModel.NumAttentionHeads
	} else if strings.HasSuffix(name, ".attn_k.weight") {
		heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
	} else {
		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
	}

	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
	if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
		return nil, err
	}

	if err := n.T(0, 2, 1, 3); err != nil {
		return nil, err
	}

	if err := n.Reshape(dims...); err != nil {
		return nil, err
	}

	if err := n.Transpose(); err != nil {
		return nil, err
	}

	ts, err := native.SelectF32(n, 1)
	if err != nil {
		return nil, err
	}

	var f32s []float32
	for _, t := range ts {
		f32s = append(f32s, t...)
	}

	return f32s, nil
}
@@ -6,7 +6,7 @@ import (
 	"slices"
 	"strings"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type mixtralModel struct {

@@ -15,7 +15,7 @@ type mixtralModel struct {
 	NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
 }

-func (p *mixtralModel) KV(t *Tokenizer) llm.KV {
+func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.llamaModel.KV(t)

 	if p.NumLocalExperts > 0 {

@@ -29,7 +29,7 @@ func (p *mixtralModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *mixtralModel) Tensors(ts []Tensor) []llm.Tensor {
+func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
 	oldnew := []string{
 		"model.layers", "blk",
 		"w1", "ffn_gate_exps",

@@ -56,10 +56,10 @@ func (p *mixtralModel) Tensors(ts []Tensor) []llm.Tensor {
 		return true
 	})

-	var out []llm.Tensor
+	var out []ggml.Tensor
 	for n, e := range experts {
 		// TODO(mxyng): sanity check experts
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  n,
 			Kind:  e[0].Kind(),
 			Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
@@ -8,7 +8,7 @@ import (
 	"strings"
 	"sync"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type phi3Model struct {

@@ -37,7 +37,7 @@ type phi3Model struct {

 var _ ModelConverter = (*phi3Model)(nil)

-func (p *phi3Model) KV(t *Tokenizer) llm.KV {
+func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "phi3"
 	kv["phi3.context_length"] = p.MaxPositionEmbeddings

@@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *phi3Model) Tensors(ts []Tensor) []llm.Tensor {
+func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
 	var addRopeFactors sync.Once

-	out := make([]llm.Tensor, 0, len(ts)+2)
+	out := make([]ggml.Tensor, 0, len(ts)+2)
 	for _, t := range ts {
 		if strings.HasPrefix(t.Name(), "blk.0.") {
 			addRopeFactors.Do(func() {
-				out = append(out, llm.Tensor{
+				out = append(out, ggml.Tensor{
 					Name:     "rope_factors_long.weight",
 					Kind:     0,
 					Shape:    []uint64{uint64(len(p.RopeScaling.LongFactor))},
 					WriterTo: p.RopeScaling.LongFactor,
-				}, llm.Tensor{
+				}, ggml.Tensor{
 					Name:     "rope_factors_short.weight",
 					Kind:     0,
 					Shape:    []uint64{uint64(len(p.RopeScaling.ShortFactor))},

@@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []llm.Tensor {
 			})
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),

@@ -118,6 +118,5 @@ func (p *phi3Model) Replacements() []string {
 type ropeFactor []float32

 func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
-	err := binary.Write(w, binary.LittleEndian, r)
-	return 0, err
+	return 0, binary.Write(w, binary.LittleEndian, r)
 }
@@ -1,6 +1,6 @@
 package convert

-import "github.com/ollama/ollama/llm"
+import "github.com/ollama/ollama/fs/ggml"

 type qwen2Model struct {
 	ModelParameters

@@ -21,7 +21,7 @@ type qwen2Model struct {

 var _ ModelConverter = (*qwen2Model)(nil)

-func (q *qwen2Model) KV(t *Tokenizer) llm.KV {
+func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
 	kv := q.ModelParameters.KV(t)
 	kv["general.architecture"] = "qwen2"
 	kv["qwen2.block_count"] = q.HiddenLayers

@@ -45,10 +45,10 @@ func (q *qwen2Model) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (q *qwen2Model) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (q *qwen2Model) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
@@ -11,7 +11,6 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
-	"math"
 	"os"
 	"path/filepath"
 	"slices"

@@ -20,7 +19,7 @@ import (

 	"golang.org/x/exp/maps"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type tensorData struct {

@@ -29,7 +28,7 @@ type tensorData struct {
 	Shape []int `json:"shape"`
 }

-func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
+func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
 	t.Helper()

 	f, err := os.CreateTemp(t.TempDir(), "f16")

@@ -48,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
 	}
 	t.Cleanup(func() { r.Close() })

-	m, _, err := llm.DecodeGGML(r, math.MaxInt)
+	m, _, err := ggml.Decode(r, -1)
 	if err != nil {
 		t.Fatal(err)
 	}

@@ -60,7 +59,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
 	return r, m.KV(), m.Tensors()
 }

-func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tensors) map[string]string {
+func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
 	actual := make(map[string]string)
 	for k, v := range kv {
 		if s, ok := v.(json.Marshaler); !ok {

@@ -75,7 +74,7 @@ func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tenso
 		}
 	}

-	for _, tensor := range tensors.Items {
+	for _, tensor := range tensors.Items() {
 		sha256sum := sha256.New()
 		sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
 		if _, err := io.Copy(sha256sum, sr); err != nil {

@@ -332,7 +331,7 @@ func TestConvertAdapter(t *testing.T) {
 	}
 	defer r.Close()

-	m, _, err := llm.DecodeGGML(r, math.MaxInt)
+	m, _, err := ggml.Decode(r, -1)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -11,14 +11,15 @@ type Tensor interface {
 	Name() string
 	Shape() []uint64
 	Kind() uint32
-	SetRepacker(repacker)
+	SetRepacker(Repacker)
 	WriteTo(io.Writer) (int64, error)
+	Clone() Tensor
 }

 type tensorBase struct {
 	name  string
 	shape []uint64
-	repacker
+	repacker Repacker
 }

 func (t tensorBase) Name() string {

@@ -36,7 +37,8 @@ const (

 func (t tensorBase) Kind() uint32 {
 	if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
-		t.name == "token_types.weight" {
+		t.name == "token_types.weight" ||
+		t.name == "v.positional_embedding_vlm" {
 		// these tensors are always F32
 		return 0
 	}

@@ -51,21 +53,18 @@ func (t tensorBase) Kind() uint32 {
 	}
 }

-func (t *tensorBase) SetRepacker(fn repacker) {
+func (t *tensorBase) SetRepacker(fn Repacker) {
 	t.repacker = fn
 }

-type repacker func(string, []float32, []uint64) ([]float32, error)
+type Repacker func(string, []float32, []uint64) ([]float32, error)

 func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
 	patterns := []struct {
 		Pattern string
 		Func    func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
 	}{
-		{"model-*-of-*.safetensors", parseSafetensors},
-		{"model.safetensors", parseSafetensors},
-		{"adapters.safetensors", parseSafetensors},
-		{"adapter_model.safetensors", parseSafetensors},
+		{"*.safetensors", parseSafetensors},
 		{"pytorch_model-*-of-*.bin", parseTorch},
 		{"pytorch_model.bin", parseTorch},
 		{"consolidated.*.pth", parseTorch},
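The parseTensors change above collapses the four enumerated safetensors patterns into a single *.safetensors glob. A small sketch, not part of this diff, showing that the broader pattern still matches the previously listed filenames; the in-memory filesystem and its entries are hypothetical:

package main

import (
	"fmt"
	"io/fs"
	"testing/fstest"
)

func main() {
	fsys := fstest.MapFS{
		"model-00001-of-00002.safetensors": {},
		"model.safetensors":                {},
		"adapter_model.safetensors":        {},
		"config.json":                      {},
	}

	// one glob now covers every naming scheme the old table spelled out
	matches, err := fs.Glob(fsys, "*.safetensors")
	if err != nil {
		panic(err)
	}
	fmt.Println(matches) // config.json is not matched
}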
@@ -94,6 +94,21 @@ type safetensor struct {
 	*tensorBase
 }

+func (st safetensor) Clone() Tensor {
+	return &safetensor{
+		fs:     st.fs,
+		path:   st.path,
+		dtype:  st.dtype,
+		offset: st.offset,
+		size:   st.size,
+		tensorBase: &tensorBase{
+			name:     st.name,
+			repacker: st.repacker,
+			shape:    slices.Clone(st.shape),
+		},
+	}
+}
+
 func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 	f, err := st.fs.Open(st.path)
 	if err != nil {
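Clone is needed because the llama4 gate/up split installs a different repacker on each half; if both output tensors shared one *tensorBase, the second SetRepacker call would silently overwrite the first. A toy illustration of that aliasing hazard, not part of this diff; the struct below is a stand-in, not the converter's actual type:

package main

import "fmt"

type base struct {
	name     string
	repacker func([]float32) []float32
}

func main() {
	shared := &base{name: "ffn_gate_up_exps"}

	// two "tensors" aliasing the same base: the up repacker clobbers the gate one
	gate, up := shared, shared
	gate.repacker = func(d []float32) []float32 { return d[:len(d)/2] }
	up.repacker = func(d []float32) []float32 { return d[len(d)/2:] }
	fmt.Println(gate.repacker([]float32{1, 2, 3, 4})) // [3 4], gate now behaves like up

	// cloning the base gives each output an independent repacker
	gateClone, upClone := *shared, *shared
	gateClone.repacker = func(d []float32) []float32 { return d[:len(d)/2] }
	upClone.repacker = func(d []float32) []float32 { return d[len(d)/2:] }
	fmt.Println(gateClone.repacker([]float32{1, 2, 3, 4})) // [1 2]
}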
@@ -43,6 +43,17 @@ type torch struct {
 	*tensorBase
 }

+func (t torch) Clone() Tensor {
+	return torch{
+		storage: t.storage,
+		tensorBase: &tensorBase{
+			name:     t.name,
+			shape:    t.shape,
+			repacker: t.repacker,
+		},
+	}
+}
+
 func (pt torch) WriteTo(w io.Writer) (int64, error) {
 	return 0, nil
 }
@@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {

 var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
 var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
-var file_sentencepiece_model_proto_goTypes = []interface{}{
+var file_sentencepiece_model_proto_goTypes = []any{
 	(TrainerSpec_ModelType)(0),         // 0: sentencepiece.TrainerSpec.ModelType
 	(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
 	(*TrainerSpec)(nil),                // 2: sentencepiece.TrainerSpec

@@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
 		return
 	}
 	if !protoimpl.UnsafeEnabled {
-		file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any {
 			switch v := v.(*TrainerSpec); i {
 			case 0:
 				return &v.state

@@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any {
 			switch v := v.(*NormalizerSpec); i {
 			case 0:
 				return &v.state

@@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any {
 			switch v := v.(*SelfTestData); i {
 			case 0:
 				return &v.state

@@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any {
 			switch v := v.(*ModelProto); i {
 			case 0:
 				return &v.state

@@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any {
 			switch v := v.(*SelfTestData_Sample); i {
 			case 0:
 				return &v.state

@@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any {
 			switch v := v.(*ModelProto_SentencePiece); i {
 			case 0:
 				return &v.state
@@ -6,7 +6,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
@@ -15,6 +17,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
||||||
|
slog.Debug("using spm vocabulary")
|
||||||
|
|
||||||
ast, err := parseAdditionalSpecialTokens(fsys)
|
ast, err := parseAdditionalSpecialTokens(fsys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -43,10 +47,19 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
|||||||
v.Types = append(v.Types, int32(t))
|
v.Types = append(v.Types, int32(t))
|
||||||
default:
|
default:
|
||||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||||
if slices.Contains(ast, piece.GetPiece()) {
|
|
||||||
|
// temporary fix to handle gemma3 broken configs
|
||||||
|
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
|
||||||
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, t := range ast {
|
||||||
|
if t.Content == piece.GetPiece() {
|
||||||
|
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
v.Types = append(v.Types, tt)
|
v.Types = append(v.Types, tt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
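The hunk above widens the default case so that tokens named in `additional_special_tokens`, as well as the hard-coded gemma3 turn markers, are typed as CONTROL instead of NORMAL. A minimal standalone sketch of that selection, assuming stand-in enum values and a plain string list rather than the `specialToken` structs introduced later in this change:

```go
// Sketch of the token-type selection above: anything listed in
// additional_special_tokens, plus the hard-coded gemma3 turn markers, is
// tagged CONTROL; everything else stays NORMAL. The constants are stand-ins
// for the sentencepiece proto enum values, not imports from the package.
package main

import (
	"fmt"
	"slices"
)

const (
	typeNormal  = 1
	typeControl = 3
)

func tokenType(piece string, additionalSpecial []string) int32 {
	tt := int32(typeNormal)
	// temporary fix for gemma3 configs that omit the turn markers from
	// additional_special_tokens
	if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece) {
		tt = typeControl
	}
	if slices.Contains(additionalSpecial, piece) {
		tt = typeControl
	}
	return tt
}

func main() {
	fmt.Println(tokenType("<start_of_turn>", nil))                    // 3 (CONTROL)
	fmt.Println(tokenType("<extra_id_0>", []string{"<extra_id_0>"})) // 3 (CONTROL)
	fmt.Println(tokenType("hello", nil))                              // 1 (NORMAL)
}
```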
@@ -78,10 +91,16 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
|||||||
return cmp.Compare(i.id, j.id)
|
return cmp.Compare(i.id, j.id)
|
||||||
})
|
})
|
||||||
|
|
||||||
n := len(v.Tokens)
|
for _, t := range ts {
|
||||||
for i, t := range ts {
|
if t.id < len(v.Tokens) {
|
||||||
if t.id != i+n {
|
if v.Tokens[t.id] == t.content {
|
||||||
return nil, fmt.Errorf("invalid token id: %d", t.id)
|
slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
|
||||||
|
}
|
||||||
|
if t.id != len(v.Tokens) {
|
||||||
|
return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
v.Tokens = append(v.Tokens, t.content)
|
v.Tokens = append(v.Tokens, t.content)
|
||||||
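This replaces the earlier strict `t.id != i+n` check with three cases: an id already inside the vocabulary is accepted only when the stored content matches (logged as a duplicate and skipped), a content mismatch is an error, and any new id must land exactly at the next append position. A small sketch of those rules, using a hypothetical `appendToken` helper rather than the loop above:

```go
// Sketch of the id checks the new loop performs when merging added tokens
// into the vocabulary; appendToken is a hypothetical helper, not code from
// the repository.
package main

import (
	"fmt"
	"log/slog"
)

func appendToken(tokens []string, id int, content string) ([]string, error) {
	if id < len(tokens) {
		if tokens[id] == content {
			// token already present at the same position: skip it
			slog.Warn("tokenizer", "duplicate token", content, "id", id)
			return tokens, nil
		}
		return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", content, tokens[id], id)
	}
	if id != len(tokens) {
		// ids must be contiguous: the next token has to land at the append position
		return nil, fmt.Errorf("invalid token id: [%d] at pos [%d]", id, len(tokens))
	}
	return append(tokens, content), nil
}

func main() {
	tokens := []string{"<pad>", "<eos>"}
	tokens, _ = appendToken(tokens, 1, "<eos>") // duplicate: warned and skipped
	tokens, _ = appendToken(tokens, 2, "<unk>") // appended at the next position
	fmt.Println(tokens)
}
```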
@@ -92,7 +111,15 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
|||||||
return &v, nil
|
return &v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
|
type specialToken struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Lstrip bool `json:"lstrip"`
|
||||||
|
Normalized bool `json:"normalized"`
|
||||||
|
Rstrip bool `json:"rstrip"`
|
||||||
|
SingleWord bool `json:"single_word"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
|
||||||
f, err := fsys.Open("special_tokens_map.json")
|
f, err := fsys.Open("special_tokens_map.json")
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -102,12 +129,43 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
|
|||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
var m struct {
|
var m struct {
|
||||||
AdditionalSpecialTokens []string `json:"additional_special_tokens"`
|
AdditionalSpecialTokens any `json:"additional_special_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewDecoder(f).Decode(&m); err != nil {
|
if err := json.NewDecoder(f).Decode(&m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.AdditionalSpecialTokens, nil
|
var ast []specialToken
|
||||||
|
|
||||||
|
switch st := m.AdditionalSpecialTokens.(type) {
|
||||||
|
case []string:
|
||||||
|
for _, s := range st {
|
||||||
|
ast = append(ast, specialToken{Content: s})
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, s := range st {
|
||||||
|
// marshal and unmarshal the object to get the special token
|
||||||
|
tMap := s.(map[string]any)
|
||||||
|
data, err := json.Marshal(tMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var token specialToken
|
||||||
|
err = json.Unmarshal(data, &token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ast = append(ast, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
slog.Warn("special token", "unknown token", reflect.TypeOf(st))
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("spm tokenizer", "additional tokens", ast)
|
||||||
|
|
||||||
|
return ast, nil
|
||||||
}
|
}
|
||||||
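`additional_special_tokens` in `special_tokens_map.json` can be either a bare list of strings or a list of objects with `content`, `lstrip`, `normalized`, `rstrip`, and `single_word` fields, which is why the field is now decoded into `any` and object entries are round-tripped through JSON into `specialToken`. A simplified standalone sketch that accepts the same two shapes, using `json.RawMessage` instead of the `any`-based switch above:

```go
// Minimal standalone sketch (not the repository's code) showing the two
// shapes "additional_special_tokens" can take and how both normalize into
// the specialToken struct used above.
package main

import (
	"encoding/json"
	"fmt"
)

type specialToken struct {
	Content    string `json:"content"`
	Lstrip     bool   `json:"lstrip"`
	Normalized bool   `json:"normalized"`
	Rstrip     bool   `json:"rstrip"`
	SingleWord bool   `json:"single_word"`
}

func parse(raw []byte) ([]specialToken, error) {
	var m struct {
		AdditionalSpecialTokens []json.RawMessage `json:"additional_special_tokens"`
	}
	if err := json.Unmarshal(raw, &m); err != nil {
		return nil, err
	}
	var ast []specialToken
	for _, entry := range m.AdditionalSpecialTokens {
		var s string
		if err := json.Unmarshal(entry, &s); err == nil {
			// plain string form: ["<start_of_turn>", "<end_of_turn>"]
			ast = append(ast, specialToken{Content: s})
			continue
		}
		// object form: [{"content": "<start_of_turn>", "lstrip": false, ...}]
		var t specialToken
		if err := json.Unmarshal(entry, &t); err != nil {
			return nil, err
		}
		ast = append(ast, t)
	}
	return ast, nil
}

func main() {
	a := []byte(`{"additional_special_tokens": ["<start_of_turn>", "<end_of_turn>"]}`)
	b := []byte(`{"additional_special_tokens": [{"content": "<start_of_turn>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false}]}`)
	for _, raw := range [][]byte{a, b} {
		ast, err := parse(raw)
		fmt.Println(ast, err)
	}
}
```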
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns
|
// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns
|
||||||
@@ -41,14 +39,11 @@ func commonAMDValidateLibDir() (string, error) {
|
|||||||
// Favor our bundled version
|
// Favor our bundled version
|
||||||
|
|
||||||
// Installer payload location if we're running the installed binary
|
// Installer payload location if we're running the installed binary
|
||||||
exe, err := os.Executable()
|
rocmTargetDir := filepath.Join(LibOllamaPath, "rocm")
|
||||||
if err == nil {
|
|
||||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), envconfig.LibRelativeToExe(), "lib", "ollama")
|
|
||||||
if rocmLibUsable(rocmTargetDir) {
|
if rocmLibUsable(rocmTargetDir) {
|
||||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||||
return rocmTargetDir, nil
|
return rocmTargetDir, nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Prefer explicit HIP env var
|
// Prefer explicit HIP env var
|
||||||
hipPath := os.Getenv("HIP_PATH")
|
hipPath := os.Getenv("HIP_PATH")
|
||||||
|
|||||||
@@ -77,8 +77,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||||||
|
|
||||||
gfxOverride := envconfig.HsaOverrideGfxVersion()
|
gfxOverride := envconfig.HsaOverrideGfxVersion()
|
||||||
var supported []string
|
var supported []string
|
||||||
depPaths := LibraryDirs()
|
var libDir string
|
||||||
libDir := ""
|
|
||||||
|
|
||||||
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
|
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
|
||||||
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
|
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
|
||||||
@@ -353,9 +352,8 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||||||
})
|
})
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
depPaths = append(depPaths, libDir)
|
|
||||||
}
|
}
|
||||||
gpuInfo.DependencyPath = depPaths
|
gpuInfo.DependencyPath = []string{libDir}
|
||||||
|
|
||||||
if gfxOverride == "" {
|
if gfxOverride == "" {
|
||||||
// Only load supported list once
|
// Only load supported list once
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -50,14 +49,13 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||||||
slog.Info(err.Error())
|
slog.Info(err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
depPaths := LibraryDirs()
|
|
||||||
libDir, err := AMDValidateLibDir()
|
libDir, err := AMDValidateLibDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("unable to verify rocm library: %w", err)
|
err = fmt.Errorf("unable to verify rocm library: %w", err)
|
||||||
slog.Warn(err.Error())
|
slog.Warn(err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
depPaths = append(depPaths, libDir)
|
|
||||||
|
|
||||||
var supported []string
|
var supported []string
|
||||||
gfxOverride := envconfig.HsaOverrideGfxVersion()
|
gfxOverride := envconfig.HsaOverrideGfxVersion()
|
||||||
@@ -113,7 +111,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||||||
UnreliableFreeMemory: true,
|
UnreliableFreeMemory: true,
|
||||||
|
|
||||||
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
||||||
DependencyPath: depPaths,
|
DependencyPath: []string{libDir},
|
||||||
MinimumMemory: rocmMinimumMemory,
|
MinimumMemory: rocmMinimumMemory,
|
||||||
Name: name,
|
Name: name,
|
||||||
Compute: gfx,
|
Compute: gfx,
|
||||||
@@ -164,9 +162,7 @@ func AMDValidateLibDir() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Installer payload (if we're running from some other location)
|
// Installer payload (if we're running from some other location)
|
||||||
localAppData := os.Getenv("LOCALAPPDATA")
|
rocmTargetDir := filepath.Join(LibOllamaPath, "rocm")
|
||||||
appDir := filepath.Join(localAppData, "Programs", "Ollama")
|
|
||||||
rocmTargetDir := filepath.Join(appDir, envconfig.LibRelativeToExe(), "lib", "ollama")
|
|
||||||
if rocmLibUsable(rocmTargetDir) {
|
if rocmLibUsable(rocmTargetDir) {
|
||||||
slog.Debug("detected ollama installed ROCm at " + rocmTargetDir)
|
slog.Debug("detected ollama installed ROCm at " + rocmTargetDir)
|
||||||
return rocmTargetDir, nil
|
return rocmTargetDir, nil
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ func IsNUMA() bool {
|
|||||||
// numa support in llama.cpp is linux only
|
// numa support in llama.cpp is linux only
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
ids := map[string]interface{}{}
|
ids := map[string]any{}
|
||||||
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
|
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
|
||||||
for _, packageId := range packageIds {
|
for _, packageId := range packageIds {
|
||||||
id, err := os.ReadFile(packageId)
|
id, err := os.ReadFile(packageId)
|
||||||
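For context, the function this mechanical `interface{}` to `any` change touches decides NUMA by collecting each CPU's `physical_package_id` from sysfs into a set and reporting true when more than one socket is present. A standalone sketch of that check (the trimming and set handling are assumptions, not copied from the package):

```go
// Standalone sketch of the NUMA check this hunk touches: read each CPU's
// physical_package_id from sysfs and report NUMA when more than one distinct
// package (socket) is present.
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
)

func isNUMA() bool {
	ids := map[string]any{}
	packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
	for _, packageId := range packageIds {
		id, err := os.ReadFile(packageId)
		if err == nil {
			ids[strings.TrimSpace(string(id))] = struct{}{}
		}
	}
	return len(ids) > 1
}

func main() {
	fmt.Println("NUMA detected:", isNUMA())
}
```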
|
|||||||
@@ -57,7 +57,8 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if gpuInfo.computeMajor < 6 || gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
|
||||||
|
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
||||||
return "v11"
|
return "v11"
|
||||||
}
|
}
|
||||||
return "v12"
|
return "v12"
|
||||||
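The new check drops the compute-capability condition and keys the CUDA runtime choice purely on the driver version: driver 12.0 had problems with the bundled v12 library, so it and anything older than 12 fall back to v11. A simplified illustration (not the repository's function):

```go
// Simplified illustration of the driver-based routing the hunk above
// introduces.
package main

import "fmt"

func cudaVariantFor(driverMajor, driverMinor int) string {
	if driverMajor < 12 || (driverMajor == 12 && driverMinor == 0) {
		return "v11"
	}
	return "v12"
}

func main() {
	fmt.Println(cudaVariantFor(11, 8)) // v11
	fmt.Println(cudaVariantFor(12, 0)) // v11 (12.0 is routed to the v11 library)
	fmt.Println(cudaVariantFor(12, 4)) // v12
}
```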
|
|||||||
@@ -23,7 +23,6 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/runners"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type cudaHandles struct {
|
type cudaHandles struct {
|
||||||
@@ -101,15 +100,7 @@ func initCudaHandles() *cudaHandles {
|
|||||||
|
|
||||||
// Aligned with driver, we can't carry as payloads
|
// Aligned with driver, we can't carry as payloads
|
||||||
nvcudaMgmtPatterns := NvcudaGlobs
|
nvcudaMgmtPatterns := NvcudaGlobs
|
||||||
|
cudartMgmtPatterns = append(cudartMgmtPatterns, filepath.Join(LibOllamaPath, "cuda_v*", CudartMgmtName))
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
localAppData := os.Getenv("LOCALAPPDATA")
|
|
||||||
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", CudartMgmtName)}
|
|
||||||
}
|
|
||||||
libDirs := LibraryDirs()
|
|
||||||
for _, d := range libDirs {
|
|
||||||
cudartMgmtPatterns = append(cudartMgmtPatterns, filepath.Join(d, CudartMgmtName))
|
|
||||||
}
|
|
||||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...)
|
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...)
|
||||||
|
|
||||||
if len(NvmlGlobs) > 0 {
|
if len(NvmlGlobs) > 0 {
|
||||||
@@ -240,7 +231,7 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("error looking up system memory", "error", err)
|
slog.Warn("error looking up system memory", "error", err)
|
||||||
}
|
}
|
||||||
depPaths := LibraryDirs()
|
|
||||||
details, err := GetCPUDetails()
|
details, err := GetCPUDetails()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to lookup CPU details", "error", err)
|
slog.Warn("failed to lookup CPU details", "error", err)
|
||||||
@@ -250,9 +241,7 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
GpuInfo: GpuInfo{
|
GpuInfo: GpuInfo{
|
||||||
memInfo: mem,
|
memInfo: mem,
|
||||||
Library: "cpu",
|
Library: "cpu",
|
||||||
Variant: runners.GetCPUCapability().String(),
|
|
||||||
ID: "0",
|
ID: "0",
|
||||||
DependencyPath: depPaths,
|
|
||||||
},
|
},
|
||||||
CPUs: details,
|
CPUs: details,
|
||||||
},
|
},
|
||||||
@@ -294,17 +283,13 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
gpuInfo.DriverMajor = driverMajor
|
gpuInfo.DriverMajor = driverMajor
|
||||||
gpuInfo.DriverMinor = driverMinor
|
gpuInfo.DriverMinor = driverMinor
|
||||||
variant := cudaVariant(gpuInfo)
|
variant := cudaVariant(gpuInfo)
|
||||||
if depPaths != nil {
|
|
||||||
gpuInfo.DependencyPath = depPaths
|
// Start with our bundled libraries
|
||||||
// Check for variant specific directory
|
|
||||||
if variant != "" {
|
if variant != "" {
|
||||||
for _, d := range depPaths {
|
variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
|
||||||
if _, err := os.Stat(filepath.Join(d, "cuda_"+variant)); err == nil {
|
if _, err := os.Stat(variantPath); err == nil {
|
||||||
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
|
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
|
||||||
gpuInfo.DependencyPath = append([]string{filepath.Join(d, "cuda_"+variant)}, gpuInfo.DependencyPath...)
|
gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||||
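The effect of the change above is that a CUDA-variant subdirectory of `LibOllamaPath`, when it exists, is prepended to the dependency path so its libraries are linked ahead of the generic ones. A short standalone sketch of that ordering, with an assumed install path:

```go
// Sketch of the dependency-path ordering above: the CUDA-variant directory is
// placed before the generic lib dir so the runtime links against the matching
// library first. Standalone illustration, not the discover package.
package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	libOllamaPath := "/usr/local/lib/ollama" // stand-in for LibOllamaPath
	variant := "v12"

	depPaths := []string{libOllamaPath}
	variantPath := filepath.Join(libOllamaPath, "cuda_"+variant)
	// prepend so the variant-specific libraries win over the generic ones
	depPaths = append([]string{variantPath}, depPaths...)

	fmt.Println(depPaths) // [/usr/local/lib/ollama/cuda_v12 /usr/local/lib/ollama]
}
```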
@@ -376,7 +361,7 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||||
gpuInfo.DependencyPath = depPaths
|
gpuInfo.DependencyPath = []string{LibOllamaPath}
|
||||||
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -512,33 +497,30 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
|
|
||||||
func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
|
func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
|
||||||
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
||||||
var ldPaths []string
|
|
||||||
gpuLibPaths := []string{}
|
gpuLibPaths := []string{}
|
||||||
slog.Debug("Searching for GPU library", "name", baseLibName)
|
slog.Debug("Searching for GPU library", "name", baseLibName)
|
||||||
|
|
||||||
// Start with our bundled libraries
|
// search our bundled libraries first
|
||||||
patterns := []string{}
|
patterns := []string{filepath.Join(LibOllamaPath, baseLibName)}
|
||||||
for _, d := range LibraryDirs() {
|
|
||||||
patterns = append(patterns, filepath.Join(d, baseLibName))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
var ldPaths []string
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "windows":
|
case "windows":
|
||||||
ldPaths = strings.Split(os.Getenv("PATH"), ";")
|
ldPaths = strings.Split(os.Getenv("PATH"), string(os.PathListSeparator))
|
||||||
case "linux":
|
case "linux":
|
||||||
ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
|
ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), string(os.PathListSeparator))
|
||||||
default:
|
|
||||||
return gpuLibPaths
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then with whatever we find in the PATH/LD_LIBRARY_PATH
|
// then search the system's LD_LIBRARY_PATH
|
||||||
for _, ldPath := range ldPaths {
|
for _, p := range ldPaths {
|
||||||
d, err := filepath.Abs(ldPath)
|
p, err := filepath.Abs(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
patterns = append(patterns, filepath.Join(d, baseLibName))
|
patterns = append(patterns, filepath.Join(p, baseLibName))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// finally, search the default patterns provided by the caller
|
||||||
patterns = append(patterns, defaultPatterns...)
|
patterns = append(patterns, defaultPatterns...)
|
||||||
slog.Debug("gpu library search", "globs", patterns)
|
slog.Debug("gpu library search", "globs", patterns)
|
||||||
for _, pattern := range patterns {
|
for _, pattern := range patterns {
|
||||||
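After this change `FindGPULibs` builds its glob list in a fixed order: the bundled `LibOllamaPath` first, then the loader's search path (`PATH` on Windows, `LD_LIBRARY_PATH` on Linux), then the caller's default patterns. A sketch of that ordering using a hypothetical `buildPatterns` helper, not an exported function from the `discover` package:

```go
// Sketch of the search-order construction in FindGPULibs after this change.
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"strings"
)

func buildPatterns(libOllamaPath, baseLibName string, defaultPatterns []string) []string {
	// 1. bundled libraries next to the ollama install come first
	patterns := []string{filepath.Join(libOllamaPath, baseLibName)}

	// 2. then whatever the dynamic loader would see (PATH on Windows,
	//    LD_LIBRARY_PATH on Linux), split with the platform separator
	var ldPaths []string
	switch runtime.GOOS {
	case "windows":
		ldPaths = strings.Split(os.Getenv("PATH"), string(os.PathListSeparator))
	case "linux":
		ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), string(os.PathListSeparator))
	}
	for _, p := range ldPaths {
		if abs, err := filepath.Abs(p); err == nil {
			patterns = append(patterns, filepath.Join(abs, baseLibName))
		}
	}

	// 3. finally, the caller-supplied default globs
	return append(patterns, defaultPatterns...)
}

func main() {
	fmt.Println(buildPatterns("/usr/local/lib/ollama", "libcudart.so*",
		[]string{"/usr/lib/x86_64-linux-gnu/libcudart.so*"}))
}
```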
@@ -715,28 +697,6 @@ func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LibraryDirs() []string {
|
|
||||||
// dependencies can exist wherever we found the runners (e.g. build tree for developers) and relative to the executable
|
|
||||||
// This can be simplified once we no longer carry runners as payloads
|
|
||||||
paths := []string{}
|
|
||||||
appExe, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("failed to lookup executable path", "error", err)
|
|
||||||
} else {
|
|
||||||
appRelative := filepath.Join(filepath.Dir(appExe), envconfig.LibRelativeToExe(), "lib", "ollama")
|
|
||||||
if _, err := os.Stat(appRelative); err == nil {
|
|
||||||
paths = append(paths, appRelative)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rDir := runners.Locate()
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("unable to locate gpu dependency libraries", "error", err)
|
|
||||||
} else {
|
|
||||||
paths = append(paths, filepath.Dir(rDir))
|
|
||||||
}
|
|
||||||
return paths
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetSystemInfo() SystemInfo {
|
func GetSystemInfo() SystemInfo {
|
||||||
gpus := GetGPUInfo()
|
gpus := GetGPUInfo()
|
||||||
gpuMutex.Lock()
|
gpuMutex.Lock()
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/runners"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -28,7 +27,6 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
return []GpuInfo{
|
return []GpuInfo{
|
||||||
{
|
{
|
||||||
Library: "cpu",
|
Library: "cpu",
|
||||||
Variant: runners.GetCPUCapability().String(),
|
|
||||||
memInfo: mem,
|
memInfo: mem,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -51,7 +49,6 @@ func GetCPUInfo() GpuInfoList {
|
|||||||
return []GpuInfo{
|
return []GpuInfo{
|
||||||
{
|
{
|
||||||
Library: "cpu",
|
Library: "cpu",
|
||||||
Variant: runners.GetCPUCapability().String(),
|
|
||||||
memInfo: mem,
|
memInfo: mem,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer file.Close()
|
||||||
return linuxCPUDetails(file)
|
return linuxCPUDetails(file)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
|
|||||||
for id, s := range socketByID {
|
for id, s := range socketByID {
|
||||||
s.CoreCount = len(coreBySocket[id])
|
s.CoreCount = len(coreBySocket[id])
|
||||||
s.ThreadCount = 0
|
s.ThreadCount = 0
|
||||||
for _, tc := range threadsByCoreBySocket[id] {
|
|
||||||
s.ThreadCount += tc
|
|
||||||
}
|
|
||||||
|
|
||||||
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
|
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
|
||||||
efficiencyCoreCount := 0
|
efficiencyCoreCount := 0
|
||||||
for _, threads := range threadsByCoreBySocket[id] {
|
for _, threads := range threadsByCoreBySocket[id] {
|
||||||
|
s.ThreadCount += threads
|
||||||
if threads == 1 {
|
if threads == 1 {
|
||||||
efficiencyCoreCount++
|
efficiencyCoreCount++
|
||||||
}
|
}
|
||||||
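The refactor above folds the per-socket thread total into the same loop that applies the efficiency-core heuristic (a core reporting a single thread while hyperthreading is enabled is assumed to be an E-core). A small sketch of that one-pass accounting with illustrative names:

```go
// Small sketch of the per-socket accounting after this change: thread totals
// and the heuristic efficiency-core count are tallied in one pass over
// threads-per-core. Names here are illustrative, not the package's types.
package main

import "fmt"

func tally(threadsByCore map[int]int) (threadCount, efficiencyCores int) {
	for _, threads := range threadsByCore {
		threadCount += threads
		// heuristic: with hyperthreading enabled, performance cores report
		// two threads, so a single-thread core is assumed to be an E-core
		if threads == 1 {
			efficiencyCores++
		}
	}
	return threadCount, efficiencyCores
}

func main() {
	// e.g. 8 P-cores with hyperthreading plus 4 E-cores
	threadsByCore := map[int]int{}
	for c := 0; c < 8; c++ {
		threadsByCore[c] = 2
	}
	for c := 8; c < 12; c++ {
		threadsByCore[c] = 1
	}
	fmt.Println(tally(threadsByCore)) // 20 4
}
```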
|
|||||||
discover/path.go (new file, 56 lines)
@@ -0,0 +1,56 @@
|
|||||||
|
package discover
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LibPath is a path to lookup dynamic libraries
|
||||||
|
// in development it's usually 'build/lib/ollama'
|
||||||
|
// in distribution builds it's 'lib/ollama' on Windows
|
||||||
|
// '../lib/ollama' on Linux and the executable's directory on macOS
|
||||||
|
// note: in distribution builds, additional GPU-specific libraries are
|
||||||
|
// found in subdirectories of the returned path, such as
|
||||||
|
// 'cuda_v11', 'cuda_v12', 'rocm', etc.
|
||||||
|
var LibOllamaPath string = func() string {
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||||
|
exe = eval
|
||||||
|
}
|
||||||
|
|
||||||
|
var libPath string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
|
||||||
|
case "linux":
|
||||||
|
libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
|
||||||
|
case "darwin":
|
||||||
|
libPath = filepath.Dir(exe)
|
||||||
|
}
|
||||||
|
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
paths := []string{
|
||||||
|
libPath,
|
||||||
|
|
||||||
|
// build paths for development
|
||||||
|
filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
|
||||||
|
filepath.Join(cwd, "build", "lib", "ollama"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range paths {
|
||||||
|
if _, err := os.Stat(p); err == nil {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Dir(exe)
|
||||||
|
}()
|
||||||
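Elsewhere in this change, GPU-specific dependencies are expected in fixed subdirectories beneath the new `LibOllamaPath` (`rocm`, `cuda_v11`, `cuda_v12`). A standalone sketch of probing that layout, with an assumed install path:

```go
// Illustrative sketch of how other parts of this change use the new
// LibOllamaPath: GPU-specific dependencies live in fixed subdirectories
// beneath it. This is a standalone example, not discover code.
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

func main() {
	libOllamaPath := "/usr/local/lib/ollama" // stand-in for discover.LibOllamaPath

	for _, sub := range []string{"rocm", "cuda_v11", "cuda_v12"} {
		dir := filepath.Join(libOllamaPath, sub)
		if _, err := os.Stat(dir); err == nil {
			fmt.Println("found GPU dependency dir:", dir)
		} else {
			fmt.Println("missing (would fall back):", dir)
		}
	}
}
```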
@@ -5,7 +5,6 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/runners"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type memInfo struct {
|
type memInfo struct {
|
||||||
@@ -107,7 +106,7 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList {
|
|||||||
for _, info := range l {
|
for _, info := range l {
|
||||||
found := false
|
found := false
|
||||||
requested := info.Library
|
requested := info.Library
|
||||||
if info.Variant != runners.CPUCapabilityNone.String() {
|
if info.Variant != "" {
|
||||||
requested += "_" + info.Variant
|
requested += "_" + info.Variant
|
||||||
}
|
}
|
||||||
for i, lib := range libs {
|
for i, lib := range libs {
|
||||||
|
|||||||
docs/api.md
@@ -31,7 +31,7 @@ Certain endpoints stream responses as JSON objects. Streaming can be disabled by
|
|||||||
|
|
||||||
## Generate a completion
|
## Generate a completion
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/generate
|
POST /api/generate
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -173,7 +173,7 @@ curl http://localhost:11434/api/generate -d '{
|
|||||||
|
|
||||||
##### Response
|
##### Response
|
||||||
|
|
||||||
```json
|
```json5
|
||||||
{
|
{
|
||||||
"model": "codellama:code",
|
"model": "codellama:code",
|
||||||
"created_at": "2024-07-22T20:47:51.147561Z",
|
"created_at": "2024-07-22T20:47:51.147561Z",
|
||||||
@@ -306,7 +306,7 @@ curl http://localhost:11434/api/generate -d '{
|
|||||||
|
|
||||||
#### Response
|
#### Response
|
||||||
|
|
||||||
```
|
```json
|
||||||
{
|
{
|
||||||
"model": "llava",
|
"model": "llava",
|
||||||
"created_at": "2023-11-03T15:36:02.583064Z",
|
"created_at": "2023-11-03T15:36:02.583064Z",
|
||||||
@@ -485,7 +485,7 @@ A single JSON object is returned:
|
|||||||
|
|
||||||
## Generate a chat completion
|
## Generate a chat completion
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/chat
|
POST /api/chat
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -495,14 +495,14 @@ Generate the next message in a chat with a provided model. This is a streaming e
|
|||||||
|
|
||||||
- `model`: (required) the [model name](#model-names)
|
- `model`: (required) the [model name](#model-names)
|
||||||
- `messages`: the messages of the chat, this can be used to keep a chat memory
|
- `messages`: the messages of the chat, this can be used to keep a chat memory
|
||||||
- `tools`: tools for the model to use if supported. Requires `stream` to be set to `false`
|
- `tools`: list of tools in JSON for the model to use if supported
|
||||||
|
|
||||||
The `message` object has the following fields:
|
The `message` object has the following fields:
|
||||||
|
|
||||||
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
|
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
|
||||||
- `content`: the content of the message
|
- `content`: the content of the message
|
||||||
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
|
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
|
||||||
- `tool_calls` (optional): a list of tools the model wants to use
|
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
|
||||||
|
|
||||||
Advanced parameters (optional):
|
Advanced parameters (optional):
|
||||||
|
|
||||||
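As a quick illustration of the `tools` parameter documented above, here is a hedged Go sketch of a non-streaming `/api/chat` request; the weather function is made up for illustration, and the tool schema follows the OpenAI-style function format this endpoint accepts:

```go
// Hedged sketch: send a /api/chat request with a tools list and print the
// raw response. The tool definition is hypothetical.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body := map[string]any{
		"model":  "llama3.2",
		"stream": false,
		"messages": []map[string]any{
			{"role": "user", "content": "What is the weather today in Paris?"},
		},
		"tools": []map[string]any{{
			"type": "function",
			"function": map[string]any{
				"name":        "get_current_weather", // hypothetical tool name
				"description": "Get the current weather for a location",
				"parameters": map[string]any{
					"type":       "object",
					"properties": map[string]any{"location": map[string]any{"type": "string"}},
					"required":   []string{"location"},
				},
			},
		}},
	}

	buf, _ := json.Marshal(body)
	resp, err := http.Post("http://localhost:11434/api/chat", "application/json", bytes.NewReader(buf))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// the assistant message in the response may carry tool_calls instead of content
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```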
@@ -558,6 +558,10 @@ Final response:
|
|||||||
{
|
{
|
||||||
"model": "llama3.2",
|
"model": "llama3.2",
|
||||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": ""
|
||||||
|
},
|
||||||
"done": true,
|
"done": true,
|
||||||
"total_duration": 4883583458,
|
"total_duration": 4883583458,
|
||||||
"load_duration": 1334875,
|
"load_duration": 1334875,
|
||||||
@@ -795,7 +799,7 @@ curl http://localhost:11434/api/chat -d '{
|
|||||||
|
|
||||||
##### Request
|
##### Request
|
||||||
|
|
||||||
```
|
```shell
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "llama3.2",
|
"model": "llama3.2",
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -870,7 +874,7 @@ If the messages array is empty, the model will be loaded into memory.
|
|||||||
|
|
||||||
##### Request
|
##### Request
|
||||||
|
|
||||||
```
|
```shell
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "llama3.2",
|
"model": "llama3.2",
|
||||||
"messages": []
|
"messages": []
|
||||||
@@ -878,6 +882,7 @@ curl http://localhost:11434/api/chat -d '{
|
|||||||
```
|
```
|
||||||
|
|
||||||
##### Response
|
##### Response
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama3.2",
|
"model": "llama3.2",
|
||||||
@@ -897,7 +902,7 @@ If the messages array is empty and the `keep_alive` parameter is set to `0`, a m
|
|||||||
|
|
||||||
##### Request
|
##### Request
|
||||||
|
|
||||||
```
|
```shell
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "llama3.2",
|
"model": "llama3.2",
|
||||||
"messages": [],
|
"messages": [],
|
||||||
@@ -924,7 +929,7 @@ A single JSON object is returned:
|
|||||||
|
|
||||||
## Create a Model
|
## Create a Model
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/create
|
POST /api/create
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1020,7 +1025,7 @@ curl http://localhost:11434/api/create -d '{
|
|||||||
|
|
||||||
A stream of JSON objects is returned:
|
A stream of JSON objects is returned:
|
||||||
|
|
||||||
```
|
```json
|
||||||
{"status":"quantizing F16 model to Q4_K_M"}
|
{"status":"quantizing F16 model to Q4_K_M"}
|
||||||
{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
|
{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
|
||||||
{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
|
{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
|
||||||
@@ -1051,7 +1056,7 @@ curl http://localhost:11434/api/create -d '{
|
|||||||
|
|
||||||
A stream of JSON objects is returned:
|
A stream of JSON objects is returned:
|
||||||
|
|
||||||
```
|
```json
|
||||||
{"status":"parsing GGUF"}
|
{"status":"parsing GGUF"}
|
||||||
{"status":"using existing layer sha256:432f310a77f4650a88d0fd59ecdd7cebed8d684bafea53cbff0473542964f0c3"}
|
{"status":"using existing layer sha256:432f310a77f4650a88d0fd59ecdd7cebed8d684bafea53cbff0473542964f0c3"}
|
||||||
{"status":"writing manifest"}
|
{"status":"writing manifest"}
|
||||||
@@ -1118,7 +1123,7 @@ Return 200 OK if the blob exists, 404 Not Found if it does not.
|
|||||||
|
|
||||||
## Push a Blob
|
## Push a Blob
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/blobs/:digest
|
POST /api/blobs/:digest
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1142,7 +1147,7 @@ Return 201 Created if the blob was successfully created, 400 Bad Request if the
|
|||||||
|
|
||||||
## List Local Models
|
## List Local Models
|
||||||
|
|
||||||
```shell
|
```
|
||||||
GET /api/tags
|
GET /api/tags
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1195,7 +1200,7 @@ A single JSON object will be returned.
|
|||||||
|
|
||||||
## Show Model Information
|
## Show Model Information
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/show
|
POST /api/show
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1212,13 +1217,13 @@ Show information about a model including details, modelfile, template, parameter
|
|||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/show -d '{
|
curl http://localhost:11434/api/show -d '{
|
||||||
"model": "llama3.2"
|
"model": "llava"
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Response
|
#### Response
|
||||||
|
|
||||||
```json
|
```json5
|
||||||
{
|
{
|
||||||
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
|
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
|
||||||
"parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"",
|
"parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"",
|
||||||
@@ -1255,13 +1260,17 @@ curl http://localhost:11434/api/show -d '{
|
|||||||
"tokenizer.ggml.pre": "llama-bpe",
|
"tokenizer.ggml.pre": "llama-bpe",
|
||||||
"tokenizer.ggml.token_type": [], // populates if `verbose=true`
|
"tokenizer.ggml.token_type": [], // populates if `verbose=true`
|
||||||
"tokenizer.ggml.tokens": [] // populates if `verbose=true`
|
"tokenizer.ggml.tokens": [] // populates if `verbose=true`
|
||||||
}
|
},
|
||||||
|
"capabilities": [
|
||||||
|
"completion",
|
||||||
|
"vision"
|
||||||
|
],
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Copy a Model
|
## Copy a Model
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/copy
|
POST /api/copy
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1284,7 +1293,7 @@ Returns a 200 OK if successful, or a 404 Not Found if the source model doesn't e
|
|||||||
|
|
||||||
## Delete a Model
|
## Delete a Model
|
||||||
|
|
||||||
```shell
|
```
|
||||||
DELETE /api/delete
|
DELETE /api/delete
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1310,7 +1319,7 @@ Returns a 200 OK if successful, 404 Not Found if the model to be deleted doesn't
|
|||||||
|
|
||||||
## Pull a Model
|
## Pull a Model
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/pull
|
POST /api/pull
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1382,7 +1391,7 @@ if `stream` is set to false, then the response is a single JSON object:
|
|||||||
|
|
||||||
## Push a Model
|
## Push a Model
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/push
|
POST /api/push
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1447,7 +1456,7 @@ If `stream` is set to `false`, then the response is a single JSON object:
|
|||||||
|
|
||||||
## Generate Embeddings
|
## Generate Embeddings
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/embed
|
POST /api/embed
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1515,7 +1524,7 @@ curl http://localhost:11434/api/embed -d '{
|
|||||||
```
|
```
|
||||||
|
|
||||||
## List Running Models
|
## List Running Models
|
||||||
```shell
|
```
|
||||||
GET /api/ps
|
GET /api/ps
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1562,7 +1571,7 @@ A single JSON object will be returned.
|
|||||||
|
|
||||||
> Note: this endpoint has been superseded by `/api/embed`
|
> Note: this endpoint has been superseded by `/api/embed`
|
||||||
|
|
||||||
```shell
|
```
|
||||||
POST /api/embeddings
|
POST /api/embeddings
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1602,7 +1611,7 @@ curl http://localhost:11434/api/embeddings -d '{
|
|||||||
|
|
||||||
## Version
|
## Version
|
||||||
|
|
||||||
```shell
|
```
|
||||||
GET /api/version
|
GET /api/version
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
docs/benchmark.md (new file, 59 lines)
@@ -0,0 +1,59 @@
|
|||||||
|
# Benchmark
|
||||||
|
|
||||||
|
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
|
||||||
|
|
||||||
|
## When to use
|
||||||
|
|
||||||
|
Run these benchmarks when:
|
||||||
|
- Making changes to the model inference engine
|
||||||
|
- Modifying model loading/unloading logic
|
||||||
|
- Changing prompt processing or token generation code
|
||||||
|
- Implementing a new model architecture
|
||||||
|
- Testing performance across different hardware setups
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
|
||||||
|
## Usage and Examples
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> All commands must be run from the root directory of the Ollama project.
|
||||||
|
|
||||||
|
Basic syntax:
|
||||||
|
```bash
|
||||||
|
go test -bench=. ./benchmark/... -m $MODEL_NAME
|
||||||
|
```
|
||||||
|
|
||||||
|
Required flags:
|
||||||
|
- `-bench=.`: Run all benchmarks
|
||||||
|
- `-m`: Model name to benchmark
|
||||||
|
|
||||||
|
Optional flags:
|
||||||
|
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
|
||||||
|
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
|
||||||
|
|
||||||
|
Common usage patterns:
|
||||||
|
|
||||||
|
Single benchmark run with a model specified:
|
||||||
|
```bash
|
||||||
|
go test -bench=. ./benchmark/... -m llama3.3
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output metrics
|
||||||
|
|
||||||
|
The benchmark reports several key metrics:
|
||||||
|
|
||||||
|
- `gen_tok/s`: Generated tokens per second
|
||||||
|
- `prompt_tok/s`: Prompt processing tokens per second
|
||||||
|
- `ttft_ms`: Time to first token in milliseconds
|
||||||
|
- `load_ms`: Model load time in milliseconds
|
||||||
|
- `gen_tokens`: Total tokens generated
|
||||||
|
- `prompt_tokens`: Total prompt tokens processed
|
||||||
|
|
||||||
|
Each benchmark runs two scenarios:
|
||||||
|
- Cold start: Model is loaded from disk for each test
|
||||||
|
- Warm start: Model is pre-loaded in memory
|
||||||
|
|
||||||
|
Three prompt lengths are tested for each scenario:
|
||||||
|
- Short prompt (100 tokens)
|
||||||
|
- Medium prompt (500 tokens)
|
||||||
|
- Long prompt (1000 tokens)
|
||||||
@@ -1,165 +1,159 @@
|
|||||||
# Development
|
# Development
|
||||||
|
|
||||||
Install required tools:
|
Install prerequisites:
|
||||||
|
|
||||||
- go version 1.22 or higher
|
- [Go](https://go.dev/doc/install)
|
||||||
- OS specific C/C++ compiler (see below)
|
- C/C++ Compiler e.g. Clang on macOS, [TDM-GCC](https://github.com/jmeubank/tdm-gcc/releases/latest) (Windows amd64) or [llvm-mingw](https://github.com/mstorsjo/llvm-mingw) (Windows arm64), GCC/Clang on Linux.
|
||||||
- GNU Make
|
|
||||||
|
|
||||||
|
Then build and run Ollama from the root directory of the repository:
|
||||||
|
|
||||||
## Overview
|
```shell
|
||||||
|
go run . serve
|
||||||
Ollama uses a mix of Go and C/C++ code to interface with GPUs. The C/C++ code is compiled with both CGO and GPU library specific compilers. A set of GNU Makefiles are used to compile the project. GPU Libraries are auto-detected based on the typical environment variables used by the respective libraries, but can be overridden if necessary. The default make target will build the runners and primary Go Ollama application that will run within the repo directory. Throughout the examples below `-j 5` is suggested for 5 parallel jobs to speed up the build. You can adjust the job count based on your CPU Core count to reduce build times. If you want to relocate the built binaries, use the `dist` target and recursively copy the files in `./dist/$OS-$ARCH/` to your desired location. To learn more about the other make targets use `make help`
|
|
||||||
|
|
||||||
Once you have built the GPU/CPU runners, you can compile the main application with `go build .`
|
|
||||||
|
|
||||||
### MacOS
|
|
||||||
|
|
||||||
[Download Go](https://go.dev/dl/)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make -j 5
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Now you can run `ollama`:
|
## macOS (Apple Silicon)
|
||||||
|
|
||||||
```bash
|
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
|
||||||
./ollama
|
|
||||||
|
## macOS (Intel)
|
||||||
|
|
||||||
|
Install prerequisites:
|
||||||
|
|
||||||
|
- [CMake](https://cmake.org/download/) or `brew install cmake`
|
||||||
|
|
||||||
|
Then, configure and build the project:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cmake -B build
|
||||||
|
cmake --build build
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Xcode 15 warnings
|
Lastly, run Ollama:
|
||||||
|
|
||||||
If you are using Xcode newer than version 14, you may see a warning during `go build` about `ld: warning: ignoring duplicate libraries: '-lobjc'` due to Golang issue https://github.com/golang/go/issues/67799 which can be safely ignored. You can suppress the warning with `export CGO_LDFLAGS="-Wl,-no_warn_duplicate_libraries"`
|
```shell
|
||||||
|
go run . serve
|
||||||
### Linux
|
|
||||||
|
|
||||||
#### Linux CUDA (NVIDIA)
|
|
||||||
|
|
||||||
_Your operating system distribution may already have packages for NVIDIA CUDA. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_
|
|
||||||
|
|
||||||
Install `make`, `gcc` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads)
|
|
||||||
development and runtime packages.
|
|
||||||
|
|
||||||
Typically the makefile will auto-detect CUDA, however, if your Linux distro
|
|
||||||
or installation approach uses alternative paths, you can specify the location by
|
|
||||||
overriding `CUDA_PATH` to the location of the CUDA toolkit. You can customize
|
|
||||||
a set of target CUDA architectures by setting `CUDA_ARCHITECTURES` (e.g. `CUDA_ARCHITECTURES=50;60;70`)
|
|
||||||
|
|
||||||
```
|
|
||||||
make -j 5
|
|
||||||
```
|
```
|
||||||
|
|
||||||
If both v11 and v12 toolkits are detected, runners for both major versions will be built by default. You can build just v12 with `make cuda_v12`
|
## Windows
|
||||||
|
|
||||||
#### Older Linux CUDA (NVIDIA)
|
Install prerequisites:
|
||||||
|
|
||||||
To support older GPUs with Compute Capability 3.5 or 3.7, you will need to use an older version of the Driver from [Unix Driver Archive](https://www.nvidia.com/en-us/drivers/unix/) (tested with 470) and [CUDA Toolkit Archive](https://developer.nvidia.com/cuda-toolkit-archive) (tested with cuda V11). When you build Ollama, you will need to set two make variable to adjust the minimum compute capability Ollama supports via `make -j 5 CUDA_ARCHITECTURES="35;37;50;52" EXTRA_GOLDFLAGS="\"-X=github.com/ollama/ollama/discover.CudaComputeMajorMin=3\" \"-X=github.com/ollama/ollama/discover.CudaComputeMinorMin=5\""`. To find the Compute Capability of your older GPU, refer to [GPU Compute Capability](https://developer.nvidia.com/cuda-gpus).
|
- [CMake](https://cmake.org/download/)
|
||||||
|
- [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) including the Native Desktop Workload
|
||||||
|
- (Optional) AMD GPU support
|
||||||
|
- [ROCm](https://rocm.docs.amd.com/en/latest/)
|
||||||
|
- [Ninja](https://github.com/ninja-build/ninja/releases)
|
||||||
|
- (Optional) NVIDIA GPU support
|
||||||
|
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
|
||||||
|
|
||||||
#### Linux ROCm (AMD)
|
Then, configure and build the project:
|
||||||
|
|
||||||
_Your operating system distribution may already have packages for AMD ROCm. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_
|
```shell
|
||||||
|
cmake -B build
|
||||||
Install [ROCm](https://rocm.docs.amd.com/en/latest/) development packages first, as well as `make`, `gcc`, and `golang`.
|
cmake --build build --config Release
|
||||||
|
|
||||||
Typically the build scripts will auto-detect ROCm, however, if your Linux distro
|
|
||||||
or installation approach uses unusual paths, you can specify the location by
|
|
||||||
specifying an environment variable `HIP_PATH` to the location of the ROCm
|
|
||||||
install (typically `/opt/rocm`). You can also customize
|
|
||||||
the AMD GPU targets by setting HIP_ARCHS (e.g. `HIP_ARCHS=gfx1101;gfx1102`)
|
|
||||||
|
|
||||||
```
|
|
||||||
make -j 5
|
|
||||||
```
|
```
|
||||||
|
|
||||||
ROCm requires elevated privileges to access the GPU at runtime. On most distros you can add your user account to the `render` group, or run as root.
|
> [!IMPORTANT]
|
||||||
|
> Building for ROCm requires additional flags:
|
||||||
|
> ```
|
||||||
|
> cmake -B build -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
|
||||||
|
> cmake --build build --config Release
|
||||||
|
> ```
|
||||||
|
|
||||||
#### Containerized Linux Build
|
|
||||||
|
|
||||||
If you have Docker and buildx available, you can build linux binaries with `./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included. The resulting artifacts are placed in `./dist` and by default the script builds both arm64 and amd64 binaries. If you want to build only amd64, you can build with `PLATFORM=linux/amd64 ./scripts/build_linux.sh`
|
Lastly, run Ollama:
|
||||||
|
|
||||||
### Windows
|
```shell
|
||||||
|
go run . serve
|
||||||
The following tools are required as a minimal development environment to build CPU inference support.
|
|
||||||
|
|
||||||
- Go version 1.22 or higher
|
|
||||||
- https://go.dev/dl/
|
|
||||||
- Git
|
|
||||||
- https://git-scm.com/download/win
|
|
||||||
- clang with gcc compat and Make. There are multiple options on how to go about installing these tools on Windows. We have verified the following, but others may work as well:
|
|
||||||
- [MSYS2](https://www.msys2.org/)
|
|
||||||
- After installing, from an MSYS2 terminal, run `pacman -S mingw-w64-clang-x86_64-gcc-compat mingw-w64-clang-x86_64-clang make` to install the required tools
|
|
||||||
- Assuming you used the default install prefix for msys2 above, add `C:\msys64\clang64\bin` and `c:\msys64\usr\bin` to your environment variable `PATH` where you will perform the build steps below (e.g. system-wide, account-level, powershell, cmd, etc.)
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> Due to bugs in the GCC C++ library for unicode support, Ollama should be built with clang on windows.
|
|
||||||
|
|
||||||
```
|
|
||||||
make -j 5
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### GPU Support
|
## Windows (ARM)
|
||||||
|
|
||||||
The GPU tools require the Microsoft native build tools. To build either CUDA or ROCm, you must first install MSVC via Visual Studio:
|
Windows ARM does not support additional acceleration libraries at this time. Do not use CMake; simply `go run` or `go build`.
|
||||||
|
|
||||||
- Make sure to select `Desktop development with C++` as a Workload during the Visual Studio install
|
## Linux
|
||||||
- You must complete the Visual Studio install and run it once **BEFORE** installing CUDA or ROCm for the tools to properly register
|
|
||||||
- Add the location of the **64 bit (x64)** compiler (`cl.exe`) to your `PATH`
|
|
||||||
- Note: the default Developer Shell may configure the 32 bit (x86) compiler which will lead to build failures. Ollama requires a 64 bit toolchain.
|
|
||||||
|
|
||||||
#### Windows CUDA (NVIDIA)
|
Install prerequisites:
|
||||||
|
|
||||||
In addition to the common Windows development tools and MSVC described above:
|
- [CMake](https://cmake.org/download/) or `sudo apt install cmake` or `sudo dnf install cmake`
|
||||||
|
- (Optional) AMD GPU support
|
||||||
|
- [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html)
|
||||||
|
- (Optional) NVIDIA GPU support
|
||||||
|
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads)
|
||||||
|
|
||||||
- [NVIDIA CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html)
|
> [!IMPORTANT]
|
||||||
|
> Ensure prerequisites are in `PATH` before running CMake.
|
||||||
|
|
||||||
#### Windows ROCm (AMD Radeon)
|
|
||||||
|
|
||||||
In addition to the common Windows development tools and MSVC described above:
|
Then, configure and build the project:
|
||||||
|
|
||||||
- [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
|
```shell
|
||||||
|
cmake -B build
|
||||||
#### Windows arm64
|
cmake --build build
|
||||||
|
|
||||||
The default `Developer PowerShell for VS 2022` may default to x86 which is not what you want. To ensure you get an arm64 development environment, start a plain PowerShell terminal and run:
|
|
||||||
|
|
||||||
```powershell
|
|
||||||
import-module 'C:\\Program Files\\Microsoft Visual Studio\\2022\\Community\\Common7\\Tools\\Microsoft.VisualStudio.DevShell.dll'
|
|
||||||
Enter-VsDevShell -Arch arm64 -vsinstallpath 'C:\\Program Files\\Microsoft Visual Studio\\2022\\Community' -skipautomaticlocation
|
|
||||||
```
|
```
|
||||||
|
|
||||||
You can confirm with `write-host $env:VSCMD_ARG_TGT_ARCH`
|
Lastly, run Ollama:
|
||||||
|
|
||||||
Follow the instructions at https://www.msys2.org/wiki/arm64/ to set up an arm64 msys2 environment. Ollama requires gcc and mingw32-make to compile, which is not currently available on Windows arm64, but a gcc compatibility adapter is available via `mingw-w64-clang-aarch64-gcc-compat`. At a minimum you will need to install the following:
|
```shell
|
||||||
|
go run . serve
|
||||||
```
|
|
||||||
pacman -S mingw-w64-clang-aarch64-clang mingw-w64-clang-aarch64-gcc-compat mingw-w64-clang-aarch64-make make
|
|
||||||
```
|
```
|
||||||
|
|
||||||
You will need to ensure your PATH includes go, cmake, gcc and clang mingw32-make to build ollama from source. (typically `C:\msys64\clangarm64\bin\`)
|
## Docker
|
||||||
|
|
||||||
|
```shell
|
||||||
## Advanced CPU Vector Settings
|
docker build .
|
||||||
|
|
||||||
On x86, running `make` will compile several CPU runners which can run on different CPU families. At runtime, Ollama will auto-detect the best variation to load. If GPU libraries are present at build time, Ollama also compiles GPU runners with the `AVX` CPU vector feature enabled. This provides a good performance balance when loading large models that split across GPU and CPU with broad compatibility. Some users may prefer no vector extensions (e.g. older Xeon/Celeron processors, or hypervisors that mask the vector features) while other users may prefer turning on many more vector extensions to further improve performance for split model loads.
|
|
||||||
|
|
||||||
To customize the set of CPU vector features enabled for a CPU runner and all GPU runners, use CUSTOM_CPU_FLAGS during the build.
|
|
||||||
|
|
||||||
To build without any vector flags:
|
|
||||||
|
|
||||||
```
|
|
||||||
make CUSTOM_CPU_FLAGS=""
|
|
||||||
```
|
```
|
||||||
|
|
||||||
To build with both AVX and AVX2:
|
### ROCm
|
||||||
```
|
|
||||||
make CUSTOM_CPU_FLAGS=avx,avx2
|
```shell
|
||||||
|
docker build --build-arg FLAVOR=rocm .
|
||||||
```
|
```
|
||||||
|
|
||||||
To build with AVX512 features turned on:
|
## Running tests
|
||||||
|
|
||||||
```
|
To run tests, use `go test`:
|
||||||
make CUSTOM_CPU_FLAGS=avx,avx2,avx512,avx512vbmi,avx512vnni,avx512bf16
|
|
||||||
|
```shell
|
||||||
|
go test ./...
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!NOTE]
|
> NOTE: In rare circumstances, you may need to change a package using the new
|
||||||
> If you are experimenting with different flags, make sure to do a `make clean` between each change to ensure everything is rebuilt with the new compiler flags
|
> "synctest" package in go1.24.
|
||||||
|
>
|
||||||
|
> If you do not have the "synctest" package enabled, you will not see build or
|
||||||
|
> test failures resulting from your change(s), if any, locally, but CI will
|
||||||
|
> break.
|
||||||
|
>
|
||||||
|
> If you see failures in CI, you can either keep pushing changes to see if the
|
||||||
|
> CI build passes, or you can enable the "synctest" package locally to see the
|
||||||
|
> failures before pushing.
|
||||||
|
>
|
||||||
|
> To enable the "synctest" package for testing, run the following command:
|
||||||
|
>
|
||||||
|
> ```shell
|
||||||
|
> GOEXPERIMENT=synctest go test ./...
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> If you wish to enable synctest for all go commands, you can set the
|
||||||
|
> `GOEXPERIMENT` environment variable in your shell profile or by using:
|
||||||
|
>
|
||||||
|
> ```shell
|
||||||
|
> go env -w GOEXPERIMENT=synctest
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> Which will enable the "synctest" package for all go commands without needing
|
||||||
|
> to set it for all shell sessions.
|
||||||
|
>
|
||||||
|
> The synctest package is not required for production builds.
|
||||||
|
|
||||||
|
## Library detection
|
||||||
|
|
||||||
|
Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
|
||||||
|
|
||||||
|
* `./lib/ollama` (Windows)
|
||||||
|
* `../lib/ollama` (Linux)
|
||||||
|
* `.` (macOS)
|
||||||
|
* `build/lib/ollama` (for development)
|
||||||
|
|
||||||
|
If the libraries are not found, Ollama will not run with any acceleration libraries.
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
### CPU only
|
### CPU only
|
||||||
|
|
||||||
```bash
|
```shell
|
||||||
docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
|
docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -11,42 +11,46 @@ Install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-
|
|||||||
|
|
||||||
#### Install with Apt
|
#### Install with Apt
|
||||||
1. Configure the repository
|
1. Configure the repository
|
||||||
```bash
|
|
||||||
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
|
```shell
|
||||||
|
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
|
||||||
| sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
|
| sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
|
||||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
|
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
|
||||||
| sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
|
| sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
|
||||||
| sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
| sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install the NVIDIA Container Toolkit packages
|
2. Install the NVIDIA Container Toolkit packages
|
||||||
```bash
|
|
||||||
sudo apt-get install -y nvidia-container-toolkit
|
```shell
|
||||||
```
|
sudo apt-get install -y nvidia-container-toolkit
|
||||||
|
```
|
||||||
|
|
||||||
#### Install with Yum or Dnf
|
#### Install with Yum or Dnf
|
||||||
1. Configure the repository
|
1. Configure the repository
|
||||||
|
|
||||||
```bash
|
```shell
|
||||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/rpm/nvidia-container-toolkit.repo \
|
curl -s -L https://nvidia.github.io/libnvidia-container/stable/rpm/nvidia-container-toolkit.repo \
|
||||||
| sudo tee /etc/yum.repos.d/nvidia-container-toolkit.repo
|
| sudo tee /etc/yum.repos.d/nvidia-container-toolkit.repo
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install the NVIDIA Container Toolkit packages
|
2. Install the NVIDIA Container Toolkit packages
|
||||||
|
|
||||||
```bash
|
```shell
|
||||||
sudo yum install -y nvidia-container-toolkit
|
sudo yum install -y nvidia-container-toolkit
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Configure Docker to use Nvidia driver
|
#### Configure Docker to use Nvidia driver
|
||||||
```
|
|
||||||
|
```shell
|
||||||
sudo nvidia-ctk runtime configure --runtime=docker
|
sudo nvidia-ctk runtime configure --runtime=docker
|
||||||
sudo systemctl restart docker
|
sudo systemctl restart docker
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Start the container
|
#### Start the container
|
||||||
|
|
||||||
```bash
|
```shell
|
||||||
docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
|
docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -57,7 +61,7 @@ docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ol
|
|||||||
|
|
||||||
To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following command:
|
To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following command:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
|
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -65,7 +69,7 @@ docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 114
|
|||||||
|
|
||||||
Now you can run a model:
|
Now you can run a model:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
docker exec -it ollama ollama run llama3.2
|
docker exec -it ollama ollama run llama3.2
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
docs/faq.md
@@ -20,12 +20,18 @@ Please refer to the [GPU docs](./gpu.md).
|
|||||||
|
|
||||||
## How can I specify the context window size?
|
## How can I specify the context window size?
|
||||||
|
|
||||||
By default, Ollama uses a context window size of 2048 tokens.
|
By default, Ollama uses a context window size of 4096 tokens, unless you have a single GPU with <= 4 GB of VRAM, in which case it will default to 2048 tokens.
|
||||||
|
|
||||||
|
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
|
||||||
|
```
|
||||||
|
|
||||||
To change this when using `ollama run`, use `/set parameter`:
|
To change this when using `ollama run`, use `/set parameter`:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
/set parameter num_ctx 4096
|
/set parameter num_ctx 8192
|
||||||
```
|
```
|
||||||
|
|
||||||
When using the API, specify the `num_ctx` parameter:
|
When using the API, specify the `num_ctx` parameter:
|
||||||
@@ -35,7 +41,7 @@ curl http://localhost:11434/api/generate -d '{
|
|||||||
"model": "llama3.2",
|
"model": "llama3.2",
|
||||||
"prompt": "Why is the sky blue?",
|
"prompt": "Why is the sky blue?",
|
||||||
"options": {
|
"options": {
|
||||||
"num_ctx": 4096
|
"num_ctx": 8192
|
||||||
}
|
}
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
@@ -46,10 +52,15 @@ Use the `ollama ps` command to see what models are currently loaded into memory.
|
|||||||
|
|
||||||
```shell
|
```shell
|
||||||
ollama ps
|
ollama ps
|
||||||
NAME ID SIZE PROCESSOR UNTIL
|
|
||||||
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **Output**:
|
||||||
|
>
|
||||||
|
> ```
|
||||||
|
> NAME ID SIZE PROCESSOR UNTIL
|
||||||
|
> llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||||
|
> ```
|
||||||
|
|
||||||
The `Processor` column will show which memory the model was loaded in to:
|
The `Processor` column will show which memory the model was loaded in to:
|
||||||
* `100% GPU` means the model was loaded entirely into the GPU
|
* `100% GPU` means the model was loaded entirely into the GPU
|
||||||
* `100% CPU` means the model was loaded entirely in system memory
|
* `100% CPU` means the model was loaded entirely in system memory
|
||||||
@@ -66,7 +77,7 @@ If Ollama is run as a macOS application, environment variables should be set usi

1. For each environment variable, call `launchctl setenv`.

    ```bash
-    launchctl setenv OLLAMA_HOST "0.0.0.0"
+    launchctl setenv OLLAMA_HOST "0.0.0.0:11434"
    ```

2. Restart Ollama application.
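To confirm the value was applied (a small verification step, not part of the original instructions), you can read it back with `launchctl`:

```shell
launchctl getenv OLLAMA_HOST
```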
@@ -81,14 +92,14 @@ If Ollama is run as a systemd service, environment variables should be set using

    ```ini
    [Service]
-    Environment="OLLAMA_HOST=0.0.0.0"
+    Environment="OLLAMA_HOST=0.0.0.0:11434"
    ```

3. Save and exit.

4. Reload `systemd` and restart Ollama:

-   ```bash
+   ```shell
    systemctl daemon-reload
    systemctl restart ollama
    ```
@@ -182,6 +193,13 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11

Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.

+For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
+
+```
+# Allow all Chrome, Firefox, and Safari extensions
+OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
+```
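To allow a single extension rather than the wildcard patterns above, list its origin explicitly. The extension ID below is a hypothetical placeholder; substitute your extension's real ID:

```shell
# Replace the ID with your extension's actual ID
OLLAMA_ORIGINS=chrome-extension://abcdefghijklmnopabcdefghijklmnop ollama serve
```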
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.

## Where are models stored?
@@ -221,16 +239,19 @@ properties.

If you are using the API you can preload a model by sending the Ollama server an empty request. This works with both the `/api/generate` and `/api/chat` API endpoints.

To preload the mistral model using the generate endpoint, use:

```shell
curl http://localhost:11434/api/generate -d '{"model": "mistral"}'
```

To use the chat completions endpoint, use:

```shell
curl http://localhost:11434/api/chat -d '{"model": "mistral"}'
```

To preload a model using the CLI, use the command:

```shell
ollama run llama3.2 ""
```
@@ -250,11 +271,13 @@ If you're using the API, use the `keep_alive` parameter with the `/api/generate`

* '0' which will unload the model immediately after generating a response

For example, to preload a model and leave it in memory use:

```shell
curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": -1}'
```

To unload the model and free up memory use:

```shell
curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": 0}'
```
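If you prefer the CLI over an API call, newer Ollama releases can also unload a model directly; a brief sketch, assuming the `ollama stop` subcommand is available in your version:

```shell
ollama stop llama3.2
```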
@@ -7,7 +7,7 @@ Check your compute compatibility to see if your card is supported:

| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ----- |
-| 9.0 | NVIDIA | `H100` |
+| 9.0 | NVIDIA | `H200` `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |

@@ -38,7 +38,7 @@ Numeric IDs may be used, however ordering may vary, so UUIDs are more reliable.

You can discover the UUID of your GPUs by running `nvidia-smi -L`. If you want to
ignore the GPUs and force CPU usage, use an invalid GPU ID (e.g., "-1").
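As a concrete sketch of the selection described above, assuming your setup exposes NVIDIA GPU selection through the `CUDA_VISIBLE_DEVICES` environment variable (the UUID below is a placeholder taken from `nvidia-smi -L` output):

```shell
# List GPUs with their UUIDs, then pin the server to one of them
nvidia-smi -L
CUDA_VISIBLE_DEVICES=GPU-0123abcd-4567-89ef-0123-456789abcdef ollama serve
```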
-### Laptop Suspend Resume
+### Linux Suspend Resume

On Linux, after a suspend/resume cycle, sometimes Ollama will fail to discover
your NVIDIA GPU and fall back to running on the CPU. You can work around this
|||||||
@@ -20,13 +20,13 @@ Make sure that you use the same base model in the `FROM` command as you used to
|
|||||||
|
|
||||||
Now run `ollama create` from the directory where the `Modelfile` was created:
|
Now run `ollama create` from the directory where the `Modelfile` was created:
|
||||||
|
|
||||||
```bash
|
```shell
|
||||||
ollama create my-model
|
ollama create my-model
|
||||||
```
|
```
|
||||||
|
|
||||||
Lastly, test the model:
|
Lastly, test the model:
|
||||||
|
|
||||||
```bash
|
```shell
|
||||||
ollama run my-model
|
ollama run my-model
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ RestartSec=3

Environment="PATH=$PATH"

[Install]
-WantedBy=default.target
+WantedBy=multi-user.target
```

Then start the service:
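The start-up commands themselves fall outside this hunk; on a typical systemd host they would look something like the following (a sketch, not necessarily the exact lines in the file):

```shell
sudo systemctl daemon-reload
sudo systemctl enable ollama
sudo systemctl start ollama
```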
@@ -119,7 +119,7 @@ sudo systemctl status ollama

To customize the installation of Ollama, you can edit the systemd service file or the environment variables by running:

-```
+```shell
sudo systemctl edit ollama
```

@@ -152,7 +152,7 @@ Use `OLLAMA_VERSION` environment variable with the install script to install a s

For example:

```shell
-curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.3.9 sh
+curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
```

## Viewing logs

@@ -186,3 +186,9 @@ sudo rm -r /usr/share/ollama

sudo userdel ollama
sudo groupdel ollama
```

+Remove installed libraries:
+
+```shell
+sudo rm -rf /usr/local/lib/ollama
+```
@@ -28,7 +28,7 @@ A model file is the blueprint to create and share models with Ollama.

The format of the `Modelfile`:

-```modelfile
+```
# comment
INSTRUCTION arguments
```

@@ -49,7 +49,7 @@ INSTRUCTION arguments

An example of a `Modelfile` creating a mario blueprint:

-```modelfile
+```
FROM llama3.2
# sets the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1

@@ -67,28 +67,32 @@ To use this:

3. `ollama run choose-a-model-name`
4. Start using the model!

-More examples are available in the [examples directory](../examples).

To view the Modelfile of a given model, use the `ollama show --modelfile` command.

-```bash
-> ollama show --modelfile llama3.2
-# Modelfile generated by "ollama show"
-# To build a new Modelfile based on this one, replace the FROM line with:
-# FROM llama3.2:latest
-FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
-TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
-
-{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
-
-{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
-
-{{ .Response }}<|eot_id|>"""
-PARAMETER stop "<|start_header_id|>"
-PARAMETER stop "<|end_header_id|>"
-PARAMETER stop "<|eot_id|>"
-PARAMETER stop "<|reserved_special_token"
-```
+```shell
+ollama show --modelfile llama3.2
+```
+
+> **Output**:
+>
+> ```
+> # Modelfile generated by "ollama show"
+> # To build a new Modelfile based on this one, replace the FROM line with:
+> # FROM llama3.2:latest
+> FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
+> TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+>
+> {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+>
+> {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+>
+> {{ .Response }}<|eot_id|>"""
+> PARAMETER stop "<|start_header_id|>"
+> PARAMETER stop "<|end_header_id|>"
+> PARAMETER stop "<|eot_id|>"
+> PARAMETER stop "<|reserved_special_token"
+> ```

## Instructions
@@ -96,13 +100,13 @@ To view the Modelfile of a given model, use the `ollama show --modelfile` comman

The `FROM` instruction defines the base model to use when creating a model.

-```modelfile
+```
FROM <model name>:<tag>
```

#### Build from existing model

-```modelfile
+```
FROM llama3.2
```

@@ -113,7 +117,7 @@ Additional models can be found at:

#### Build from a Safetensors model

-```modelfile
+```
FROM <model directory>
```

@@ -127,7 +131,7 @@ Currently supported model architectures:

#### Build from a GGUF file

-```modelfile
+```
FROM ./ollama-model.gguf
```

@@ -138,7 +142,7 @@ The GGUF file location should be specified as an absolute path or relative to th

The `PARAMETER` instruction defines a parameter that can be set when the model is run.

-```modelfile
+```
PARAMETER <parameter> <parametervalue>
```

@@ -155,7 +159,6 @@ PARAMETER <parameter> <parametervalue>

| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
| seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 |
| stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" |
-| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
| num_predict | Maximum number of tokens to predict when generating text. (Default: -1, infinite generation) | int | num_predict 42 |
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |

@@ -186,7 +189,7 @@ TEMPLATE """{{ if .System }}<|im_start|>system

The `SYSTEM` instruction specifies the system message to be used in the template, if applicable.

-```modelfile
+```
SYSTEM """<system message>"""
```

@@ -196,7 +199,7 @@ The `ADAPTER` instruction specifies a fine tuned LoRA adapter that should apply

#### Safetensor adapter

-```modelfile
+```
ADAPTER <path to safetensor adapter>
```

@@ -207,7 +210,7 @@ Currently supported Safetensor adapters:

#### GGUF adapter

-```modelfile
+```
ADAPTER ./ollama-lora.gguf
```

@@ -215,7 +218,7 @@ ADAPTER ./ollama-lora.gguf

The `LICENSE` instruction allows you to specify the legal license under which the model used with this Modelfile is shared or distributed.

-```modelfile
+```
LICENSE """
<license text>
"""

@@ -225,7 +228,7 @@ LICENSE """

The `MESSAGE` instruction allows you to specify a message history for the model to use when responding. Use multiple iterations of the MESSAGE command to build up a conversation which will guide the model to answer in a similar way.

-```modelfile
+```
MESSAGE <role> <message>
```

@@ -240,7 +243,7 @@ MESSAGE <role> <message>

#### Example conversation

-```modelfile
+```
MESSAGE user Is Toronto in Canada?
MESSAGE assistant yes
MESSAGE user Is Sacramento in Canada?
@@ -1,6 +1,7 @@

# OpenAI compatibility

-> **Note:** OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/ollama/ollama/blob/main/docs/api.md).
+> [!NOTE]
+> OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/ollama/ollama/blob/main/docs/api.md).

Ollama provides experimental compatibility with parts of the [OpenAI API](https://platform.openai.com/docs/api-reference) to help connect existing applications to Ollama.

@@ -59,8 +60,10 @@ embeddings = client.embeddings.create(

input=["why is the sky blue?", "why is the grass green?"],
)
```

#### Structured outputs

-```py
+```python
from pydantic import BaseModel
from openai import OpenAI

@@ -144,7 +147,7 @@ const embedding = await openai.embeddings.create({

### `curl`

-``` shell
+```shell
curl http://localhost:11434/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{

@@ -319,7 +322,7 @@ ollama pull llama3.2

For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:

-```
+```shell
ollama cp llama3.2 gpt-3.5-turbo
```

@@ -343,7 +346,7 @@ curl http://localhost:11434/v1/chat/completions \

The OpenAI API does not have a way of setting the context size for a model. If you need to change the context size, create a `Modelfile` which looks like:

-```modelfile
+```
FROM <some model>
PARAMETER num_ctx <context size>
```
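To actually use the larger context from an OpenAI client, you would build a model from that `Modelfile` and then reference the new name in your requests; a minimal sketch, where `llama3.2-32k` and the file name `Modelfile` are illustrative:

```shell
# Build a variant with a larger context window and use it by name
ollama create llama3.2-32k -f Modelfile
```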
@@ -12,7 +12,7 @@ A basic Go template consists of three main parts:

Here's an example of a simple chat template:

-```gotmpl
+```go
{{- range .Messages }}
{{ .Role }}: {{ .Content }}
{{- end }}

@@ -162,6 +162,6 @@ CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://o

Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.

-```gotmpl
+```go
[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
```
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log

On **Linux** systems with systemd, the logs can be found with this command:

```shell
-journalctl -u ollama --no-pager
+journalctl -u ollama --no-pager --follow --pager-end
```

When you run Ollama in a **container**, the logs go to stdout/stderr in the container:

@@ -17,6 +17,7 @@ When you run Ollama in a **container**, the logs go to stdout/stderr in the cont

```shell
docker logs <container-name>
```

(Use `docker ps` to find the container name)

If manually running `ollama serve` in a terminal, the logs will be on that terminal.

@@ -25,9 +26,9 @@ When you run Ollama on **Windows**, there are a few different locations. You can

- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
-- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories

To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal

```powershell
$env:OLLAMA_DEBUG="1"
& "ollama app.exe"

@@ -49,12 +50,13 @@ Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]

You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:

-```
+```shell
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
```

You can see what features your CPU has with the following.

-```
+```shell
cat /proc/cpuinfo| grep flags | head -1
```

@@ -62,13 +64,13 @@ cat /proc/cpuinfo| grep flags | head -1

If you run into problems on Linux and want to install an older version, or you'd like to try out a pre-release before it's officially released, you can tell the install script which version to install.

-```sh
+```shell
-curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
+curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
```

-## Linux tmp noexec
+## Linux docker

-If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
+If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
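A minimal sketch of that change, assuming an otherwise empty `daemon.json` (merge the key into your existing file rather than overwriting it if you already have other settings), followed by a Docker restart:

```shell
# Disable systemd cgroup management in Docker, then restart the daemon
sudo tee /etc/docker/daemon.json <<'EOF'
{
  "exec-opts": ["native.cgroupdriver=cgroupfs"]
}
EOF
sudo systemctl restart docker
```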
## NVIDIA GPU Discovery

@@ -97,8 +99,6 @@ On linux, AMD GPU access typically requires `video` and/or `render` group member

When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`

-If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.

If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems
- `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported
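Putting the `--group-add` advice above into a concrete command, with `44` and `993` standing in for whatever numeric group IDs `ls -lnd` reported on your host (a sketch built on the ROCm container command from the Docker instructions, not a universal invocation):

```shell
docker run -d --device /dev/kfd --device /dev/dri \
  --group-add 44 --group-add 993 \
  -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
```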
@@ -47,6 +47,7 @@ If Ollama is already running, Quit the tray application and relaunch it from the

## API Access

Here's a quick example showing API access from `powershell`

```powershell
(Invoke-WebRequest -method POST -Body '{"model":"llama3.2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
```

@@ -54,14 +55,13 @@ Here's a quick example showing API access from `powershell`

## Troubleshooting

Ollama on Windows stores files in a few different locations. You can view them in
-the explorer window by hitting `<cmd>+R` and type in:
+the explorer window by hitting `<Ctrl>+R` and type in:
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
    - *app.log* contains the most recent logs from the GUI application
    - *server.log* contains the most recent server logs
    - *upgrade.log* contains log output for upgrades
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration
-- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories

## Uninstall

@@ -80,9 +80,11 @@ help you keep up to date.

If you'd like to install or integrate Ollama as a service, a standalone
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
-and GPU library dependencies for Nvidia and AMD. This allows for embedding
-Ollama in existing applications, or running it as a system service via `ollama
-serve` with tools such as [NSSM](https://nssm.cc/).
+and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
+and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
+same directory. This allows for embedding Ollama in existing applications, or
+running it as a system service via `ollama serve` with tools such as
+[NSSM](https://nssm.cc/).

> [!NOTE]
> If you are upgrading from a prior version, you should remove the old directories first.
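As an illustration of the service approach mentioned above, registering the standalone binary with NSSM might look like the following; this is only a sketch, the install path and service name are placeholders, and `nssm` has to be installed separately:

```shell
# Register and start Ollama as a Windows service via NSSM (path and name are examples)
nssm install Ollama "C:\Program Files\Ollama\ollama.exe" serve
nssm start Ollama
```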
|
|||||||
@@ -53,8 +53,8 @@ func Host() *url.URL {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
|
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
|
||||||
func Origins() (origins []string) {
|
func AllowedOrigins() (origins []string) {
|
||||||
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
||||||
origins = strings.Split(s, ",")
|
origins = strings.Split(s, ",")
|
||||||
}
|
}
|
||||||
@@ -73,6 +73,7 @@ func Origins() (origins []string) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
)
|
)
|
||||||
|
|
||||||
return origins
|
return origins
|
||||||
@@ -165,6 +166,10 @@ var (
|
|||||||
IntelGPU = Bool("OLLAMA_INTEL_GPU")
|
IntelGPU = Bool("OLLAMA_INTEL_GPU")
|
||||||
// MultiUserCache optimizes prompt caching for multi-user scenarios
|
// MultiUserCache optimizes prompt caching for multi-user scenarios
|
||||||
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
|
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
|
||||||
|
// Enable the new Ollama engine
|
||||||
|
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||||
|
// ContextLength sets the default context length
|
||||||
|
ContextLength = Int64("OLLAMA_CONTEXT_LENGTH", -1)
|
||||||
)
|
)
|
||||||
|
|
||||||
func String(s string) func() string {
|
func String(s string) func() string {
|
||||||
@@ -222,6 +227,20 @@ func Uint64(key string, defaultValue uint64) func() uint64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Int64(key string, defaultValue int64) func() int64 {
|
||||||
|
return func() int64 {
|
||||||
|
if s := Var(key); s != "" {
|
||||||
|
if n, err := strconv.ParseInt(s, 10, 64); err != nil {
|
||||||
|
slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
|
||||||
|
} else {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Set aside VRAM per GPU
|
// Set aside VRAM per GPU
|
||||||
var GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0)
|
var GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0)
|
||||||
|
|
||||||
@@ -247,9 +266,11 @@ func AsMap() map[string]EnvVar {
|
|||||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
|
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||||
|
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default 4096 or 2048 with low VRAM)"},
|
||||||
|
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||||
|
|
||||||
// Informational
|
// Informational
|
||||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||||
@@ -288,12 +309,3 @@ func Values() map[string]string {
|
|||||||
func Var(key string) string {
|
func Var(key string) string {
|
||||||
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
|
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
|
||||||
}
|
}
|
||||||
|
|
||||||
// On windows, we keep the binary at the top directory, but
|
|
||||||
// other platforms use a "bin" directory, so this returns ".."
|
|
||||||
func LibRelativeToExe() string {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
return "."
|
|
||||||
}
|
|
||||||
return ".."
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ func TestOrigins(t *testing.T) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
}},
|
}},
|
||||||
{"http://10.0.0.1", []string{
|
{"http://10.0.0.1", []string{
|
||||||
"http://10.0.0.1",
|
"http://10.0.0.1",
|
||||||
@@ -88,6 +89,7 @@ func TestOrigins(t *testing.T) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
}},
|
}},
|
||||||
{"http://172.16.0.1,https://192.168.0.1", []string{
|
{"http://172.16.0.1,https://192.168.0.1", []string{
|
||||||
"http://172.16.0.1",
|
"http://172.16.0.1",
|
||||||
@@ -108,6 +110,7 @@ func TestOrigins(t *testing.T) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
}},
|
}},
|
||||||
{"http://totally.safe,http://definitely.legit", []string{
|
{"http://totally.safe,http://definitely.legit", []string{
|
||||||
"http://totally.safe",
|
"http://totally.safe",
|
||||||
@@ -128,13 +131,14 @@ func TestOrigins(t *testing.T) {
|
|||||||
"file://*",
|
"file://*",
|
||||||
"tauri://*",
|
"tauri://*",
|
||||||
"vscode-webview://*",
|
"vscode-webview://*",
|
||||||
|
"vscode-file://*",
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.value, func(t *testing.T) {
|
t.Run(tt.value, func(t *testing.T) {
|
||||||
t.Setenv("OLLAMA_ORIGINS", tt.value)
|
t.Setenv("OLLAMA_ORIGINS", tt.value)
|
||||||
|
|
||||||
if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
|
if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" {
|
||||||
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
|
t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -272,3 +276,19 @@ func TestVar(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextLength(t *testing.T) {
|
||||||
|
cases := map[string]int64{
|
||||||
|
"": -1,
|
||||||
|
"4096": 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range cases {
|
||||||
|
t.Run(k, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_CONTEXT_LENGTH", k)
|
||||||
|
if i := ContextLength(); i != v {
|
||||||
|
t.Errorf("%s: expected %d, got %d", k, v, i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -40,8 +40,6 @@ func HumanBytes(b int64) string {
	}

	switch {
-	case value >= 100:
-		return fmt.Sprintf("%d %s", int(value), unit)
	case value >= 10:
		return fmt.Sprintf("%d %s", int(value), unit)
	case value != math.Trunc(value):
|
|||||||
91
format/bytes_test.go
Normal file
91
format/bytes_test.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package format
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHumanBytes(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
input int64
|
||||||
|
expected string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
// Test bytes (B)
|
||||||
|
{0, "0 B"},
|
||||||
|
{1, "1 B"},
|
||||||
|
{999, "999 B"},
|
||||||
|
|
||||||
|
// Test kilobytes (KB)
|
||||||
|
{1000, "1 KB"},
|
||||||
|
{1500, "1.5 KB"},
|
||||||
|
{999999, "999 KB"},
|
||||||
|
|
||||||
|
// Test megabytes (MB)
|
||||||
|
{1000000, "1 MB"},
|
||||||
|
{1500000, "1.5 MB"},
|
||||||
|
{999999999, "999 MB"},
|
||||||
|
|
||||||
|
// Test gigabytes (GB)
|
||||||
|
{1000000000, "1 GB"},
|
||||||
|
{1500000000, "1.5 GB"},
|
||||||
|
{999999999999, "999 GB"},
|
||||||
|
|
||||||
|
// Test terabytes (TB)
|
||||||
|
{1000000000000, "1 TB"},
|
||||||
|
{1500000000000, "1.5 TB"},
|
||||||
|
{1999999999999, "2.0 TB"},
|
||||||
|
|
||||||
|
// Test fractional values
|
||||||
|
{1234, "1.2 KB"},
|
||||||
|
{1234567, "1.2 MB"},
|
||||||
|
{1234567890, "1.2 GB"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.expected, func(t *testing.T) {
|
||||||
|
result := HumanBytes(tc.input)
|
||||||
|
if result != tc.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tc.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHumanBytes2(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
input uint64
|
||||||
|
expected string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
// Test bytes (B)
|
||||||
|
{0, "0 B"},
|
||||||
|
{1, "1 B"},
|
||||||
|
{1023, "1023 B"},
|
||||||
|
|
||||||
|
// Test kibibytes (KiB)
|
||||||
|
{1024, "1.0 KiB"},
|
||||||
|
{1536, "1.5 KiB"},
|
||||||
|
{1048575, "1024.0 KiB"},
|
||||||
|
|
||||||
|
// Test mebibytes (MiB)
|
||||||
|
{1048576, "1.0 MiB"},
|
||||||
|
{1572864, "1.5 MiB"},
|
||||||
|
{1073741823, "1024.0 MiB"},
|
||||||
|
|
||||||
|
// Test gibibytes (GiB)
|
||||||
|
{1073741824, "1.0 GiB"},
|
||||||
|
{1610612736, "1.5 GiB"},
|
||||||
|
{2147483648, "2.0 GiB"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.expected, func(t *testing.T) {
|
||||||
|
result := HumanBytes2(tc.input)
|
||||||
|
if result != tc.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tc.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,9 @@ func TestHumanNumber(t *testing.T) {

	testCases := []testCase{
		{0, "0"},
+		{999, "999"},
+		{1000, "1K"},
+		{1001, "1K"},
		{1000000, "1M"},
		{125000000, "125M"},
		{500500000, "500.50M"},
|
|||||||
@@ -5,7 +5,7 @@ import (
	"time"
)

-func assertEqual(t *testing.T, a interface{}, b interface{}) {
+func assertEqual(t *testing.T, a any, b any) {
	if a != b {
		t.Errorf("Assert failed, expected %v, got %v", b, a)
	}
|
|||||||
13 fs/config.go (new file)

@@ -0,0 +1,13 @@
+package fs
+
+type Config interface {
+	Architecture() string
+	String(string, ...string) string
+	Uint(string, ...uint32) uint32
+	Float(string, ...float32) float32
+	Bool(string, ...bool) bool
+
+	Strings(string, ...[]string) []string
+	Ints(string, ...[]int32) []int32
+	Floats(string, ...[]float32) []float32
+}
@@ -1,15 +1,15 @@
|
|||||||
package llm
|
package ggml
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/util/bufioutil"
|
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GGML struct {
|
type GGML struct {
|
||||||
@@ -19,121 +19,163 @@ type GGML struct {
|
|||||||
|
|
||||||
type model interface {
|
type model interface {
|
||||||
KV() KV
|
KV() KV
|
||||||
Tensors() *Tensors
|
Tensors() Tensors
|
||||||
}
|
}
|
||||||
|
|
||||||
type KV map[string]any
|
type KV map[string]any
|
||||||
|
|
||||||
func (kv KV) u64(key string) uint64 {
|
|
||||||
switch v := kv[key].(type) {
|
|
||||||
case uint64:
|
|
||||||
return v
|
|
||||||
case uint32:
|
|
||||||
return uint64(v)
|
|
||||||
case float64:
|
|
||||||
return uint64(v)
|
|
||||||
default:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kv KV) Architecture() string {
|
func (kv KV) Architecture() string {
|
||||||
if s, ok := kv["general.architecture"].(string); ok {
|
return kv.String("general.architecture", "unknown")
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
return "unknown"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) Kind() string {
|
func (kv KV) Kind() string {
|
||||||
if s, ok := kv["general.type"].(string); ok {
|
return kv.String("general.type", "unknown")
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
return "unknown"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) ParameterCount() uint64 {
|
func (kv KV) ParameterCount() uint64 {
|
||||||
return kv.u64("general.parameter_count")
|
return keyValue(kv, "general.parameter_count", uint64(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) FileType() fileType {
|
func (kv KV) FileType() fileType {
|
||||||
if u64 := kv.u64("general.file_type"); u64 > 0 {
|
if t := kv.Uint("general.file_type"); t > 0 {
|
||||||
return fileType(uint32(u64))
|
return fileType(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fileTypeUnknown
|
return fileTypeUnknown
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) BlockCount() uint64 {
|
func (kv KV) BlockCount() uint64 {
|
||||||
return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
|
return uint64(kv.Uint("block_count"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) EmbeddingLength() uint64 {
|
||||||
|
return uint64(kv.Uint("embedding_length"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) HeadCount() uint64 {
|
func (kv KV) HeadCount() uint64 {
|
||||||
return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
|
return uint64(kv.Uint("attention.head_count"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) HeadCountKV() uint64 {
|
func (kv KV) HeadCountKV() uint64 {
|
||||||
if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
|
return uint64(kv.Uint("attention.head_count_kv", 1))
|
||||||
return headCountKV
|
|
||||||
}
|
|
||||||
|
|
||||||
return 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingHeadCount() uint64 {
|
func (kv KV) EmbeddingHeadCount() uint64 {
|
||||||
if heads := kv.HeadCount(); heads > 0 {
|
if heads := kv.HeadCount(); heads > 0 {
|
||||||
return kv.EmbeddingLength() / kv.HeadCount()
|
return kv.EmbeddingLength() / heads
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingHeadCountK() uint64 {
|
func (kv KV) EmbeddingHeadCountK() uint64 {
|
||||||
if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
|
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
|
||||||
return k
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv.EmbeddingHeadCount()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingHeadCountV() uint64 {
|
func (kv KV) EmbeddingHeadCountV() uint64 {
|
||||||
if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
|
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv.EmbeddingHeadCount()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) GQA() uint64 {
|
func (kv KV) GQA() uint64 {
|
||||||
return kv.HeadCount() / kv.HeadCountKV()
|
return kv.HeadCount() / kv.HeadCountKV()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingLength() uint64 {
|
|
||||||
return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kv KV) ContextLength() uint64 {
|
func (kv KV) ContextLength() uint64 {
|
||||||
return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
|
return uint64(kv.Uint("context_length"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) ChatTemplate() string {
|
func (kv KV) ChatTemplate() string {
|
||||||
s, _ := kv["tokenizer.chat_template"].(string)
|
return kv.String("tokenizer.chat_template")
|
||||||
return s
|
}
|
||||||
|
|
||||||
|
func (kv KV) String(key string, defaultValue ...string) string {
|
||||||
|
return keyValue(kv, key, append(defaultValue, "")...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||||
|
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||||
|
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||||
|
return keyValue(kv, key, append(defaultValue, false)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||||
|
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||||
|
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||||
|
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||||
|
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) OllamaEngineRequired() bool {
|
||||||
|
return slices.Contains([]string{
|
||||||
|
"gemma3",
|
||||||
|
"mistral3",
|
||||||
|
"llama4",
|
||||||
|
}, kv.Architecture())
|
||||||
|
}
|
||||||
|
|
||||||
|
type valueTypes interface {
|
||||||
|
uint8 | int8 | uint16 | int16 |
|
||||||
|
uint32 | int32 | uint64 | int64 |
|
||||||
|
string | float32 | float64 | bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type arrayValueTypes interface {
|
||||||
|
*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
|
||||||
|
*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
|
||||||
|
*array[string] | *array[float32] | *array[float64] | *array[bool]
|
||||||
|
}
|
||||||
|
|
||||||
|
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
|
||||||
|
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||||
|
key = kv.Architecture() + "." + key
|
||||||
|
}
|
||||||
|
|
||||||
|
if val, ok := kv[key]; ok {
|
||||||
|
return val.(T)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Warn("key not found", "key", key, "default", defaultValue[0])
|
||||||
|
return defaultValue[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
 type Tensors struct {
-    Items  []*Tensor
+    items  []*Tensor
     Offset uint64
-
-    layers     map[string]Layer
-    layersOnce sync.Once
 }
 
-func (ts *Tensors) Layers() map[string]Layer {
-    ts.layersOnce.Do(func() {
-        ts.layers = make(map[string]Layer)
-        for _, t := range ts.Items {
+func (s Tensors) Items(prefix ...string) []*Tensor {
+    if len(prefix) == 0 {
+        return s.items
+    }
+
+    var items []*Tensor
+    for _, t := range s.items {
+        if strings.HasPrefix(t.Name, prefix[0]) {
+            items = append(items, t)
+        }
+    }
+
+    return items
+}
+
+func (ts Tensors) GroupLayers() map[string]Layer {
+    layers := make(map[string]Layer)
+    for _, t := range ts.items {
         parts := strings.Split(t.Name, ".")
         if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
             if len(parts) > index+2 {
@@ -144,20 +186,19 @@ func (ts *Tensors) Layers() map[string]Layer {
             }
         }
 
-        if _, ok := ts.layers[parts[0]]; !ok {
-            ts.layers[parts[0]] = make(Layer)
+        if _, ok := layers[parts[0]]; !ok {
+            layers[parts[0]] = make(Layer)
         }
 
-        ts.layers[parts[0]][strings.Join(parts[1:], ".")] = t
+        layers[parts[0]][strings.Join(parts[1:], ".")] = t
     }
-    })
 
-    return ts.layers
+    return layers
 }
 
 type Layer map[string]*Tensor
 
-func (l Layer) size() (size uint64) {
+func (l Layer) Size() (size uint64) {
     for _, t := range l {
         size += t.Size()
     }
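GroupLayers (previously Layers) now rebuilds the map on every call instead of caching it behind a sync.Once, and the exported Items accessor takes an optional name prefix. The block-merging step itself sits in the elided hunk above, so the grouping rule in the standalone sketch below is reconstructed from the test expectations added later in this diff rather than copied from the package:

package main

import (
    "fmt"
    "slices"
    "strings"
)

// groupKey splits a tensor name into its layer bucket and remaining key:
// "blk.0.attn_q.weight"   -> ("blk.0", "attn_q.weight")
// "v.blk.0.ffn_up.weight" -> ("v.blk.0", "ffn_up.weight")
// "output_norm.weight"    -> ("output_norm", "weight")
func groupKey(name string) (layer, key string) {
    parts := strings.Split(name, ".")
    if i := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); i != -1 && len(parts) > i+2 {
        return strings.Join(parts[:i+2], "."), strings.Join(parts[i+2:], ".")
    }
    return parts[0], strings.Join(parts[1:], ".")
}

func main() {
    for _, name := range []string{"blk.0.attn_q.weight", "v.blk.0.ffn_up.weight", "mm.0.bias", "output_norm.weight"} {
        layer, key := groupKey(name)
        fmt.Printf("%-24s -> %s / %s\n", name, layer, key)
    }
}
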
@@ -186,11 +227,26 @@ func (t Tensor) block() (n int) {
 
 func (t Tensor) blockSize() uint64 {
     switch t.Kind {
-    case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
+    case
+        0,  // F32
+        1,  // F16
+        24, // I8
+        25, // I16
+        26, // I32
+        27, // I64
+        28, // F64
+        30: // BF16
         return 1
-    case 2, 3, 4, 5, 6, 7, 8, 9, 20: // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, IQ4_NL
+    case
+        2,  // Q4_0
+        3,  // Q4_1
+        6,  // Q5_0
+        7,  // Q5_1
+        8,  // Q8_0
+        9,  // Q8_1
+        20: // IQ4_NL
         return 32
-    default: // All others
+    default:
         return 256
     }
 }
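Each quantization kind stores weights in fixed-size blocks, and typeSize (next hunk) gives the bytes per block; together they determine the effective bytes per weight. A quick arithmetic check using values that appear in this diff and in the new test table:

package main

import "fmt"

func main() {
    quants := []struct {
        name      string
        typeSize  float64 // bytes per block (from typeSize / the test table)
        blockSize float64 // weights per block (from blockSize above)
    }{
        {"Q4_0", 18, 32},
        {"Q8_0", 34, 32},
        {"Q6_K", 210, 256},
    }
    for _, q := range quants {
        fmt.Printf("%s: %.3f bytes/weight (%.2f bits)\n", q.name, q.typeSize/q.blockSize, 8*q.typeSize/q.blockSize)
    }
}
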
@@ -214,7 +270,7 @@ func (t Tensor) typeSize() uint64 {
     case 8: // Q8_0
         return 2 + blockSize
     case 9: // Q8_1
-        return 4 + 4 + blockSize
+        return 2 + 2 + blockSize
     case 10: // Q2_K
         return blockSize/16 + blockSize/4 + 2 + 2
     case 11: // Q3_K
@@ -226,7 +282,7 @@ func (t Tensor) typeSize() uint64 {
     case 14: // Q6_K
         return blockSize/2 + blockSize/4 + blockSize/16 + 2
     case 15: // Q8_K
-        return 2 + blockSize + 2*blockSize/16
+        return 4 + blockSize + 2*blockSize/16
     case 16: // IQ2_XXS
         return 2 + 2*blockSize/8
     case 17: // IQ2_XS
@@ -274,6 +330,10 @@ func (t Tensor) Size() uint64 {
     return t.parameters() * t.typeSize() / t.blockSize()
 }
 
+func (t Tensor) Type() string {
+    return fileType(t.Kind).String()
+}
+
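Size therefore scales a tensor's parameter count by typeSize/blockSize. A worked example for a hypothetical 4096x4096 Q4_0 tensor (18 bytes per 32-weight block, per the values above):

package main

import "fmt"

func main() {
    var params uint64 = 4096 * 4096 // 16,777,216 weights
    var typeSize uint64 = 18        // bytes per Q4_0 block
    var blockSize uint64 = 32       // weights per Q4_0 block
    fmt.Println(params * typeSize / blockSize) // 9437184 bytes, i.e. 9 MiB
}
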
 type container interface {
     Name() string
     Decode(io.ReadSeeker) (model, error)
@@ -295,7 +355,7 @@ const (
 
 var ErrUnsupportedFormat = errors.New("unsupported model format")
 
-func DetectGGMLType(b []byte) string {
+func DetectContentType(b []byte) string {
     switch binary.LittleEndian.Uint32(b[:4]) {
     case FILE_MAGIC_GGML:
         return "ggml"
@@ -312,16 +372,11 @@ func DetectGGMLType(b []byte) string {
     }
 }
 
-// DecodeGGML decodes a GGML model from the given reader.
+// Decode decodes a GGML model from the given reader.
 //
 // It collects array values for arrays with a size less than or equal to
-// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
-// the maxArraySize is negative, all arrays are collected.
-func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
-    if maxArraySize == 0 {
-        maxArraySize = 1024
-    }
-
+// maxArraySize. If the maxArraySize is negative, all arrays are collected.
+func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
     rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
 
     var magic uint32
@@ -331,10 +386,6 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 
     var c container
     switch magic {
-    case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
-        return nil, 0, ErrUnsupportedFormat
-    case FILE_MAGIC_GGLA:
-        c = &containerGGLA{}
     case FILE_MAGIC_GGUF_LE:
         c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
     case FILE_MAGIC_GGUF_BE:
@@ -360,23 +411,26 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
     }, offset, nil
 }
 
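Decode keeps DecodeGGML's signature apart from the name. A hypothetical caller (the import path assumes the package now lives under fs/ggml as in this change; the file name and error handling are illustrative only):

package main

import (
    "fmt"
    "os"

    "github.com/ollama/ollama/fs/ggml"
)

func main() {
    f, err := os.Open("model.gguf") // hypothetical path
    if err != nil {
        panic(err)
    }
    defer f.Close()

    // A negative maxArraySize collects every array; a small positive value
    // skips the contents of oversized arrays such as the tokenizer vocabulary.
    g, offset, err := ggml.Decode(f, -1)
    if err != nil {
        panic(err)
    }
    fmt.Println(g.KV().Architecture(), "tensor data begins at offset", offset)
}
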
-func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
-    embedding := llm.KV().EmbeddingLength()
-    heads := llm.KV().HeadCount()
-    headsKV := llm.KV().HeadCountKV()
-    vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
+    embedding := f.KV().EmbeddingLength()
+    heads := f.KV().HeadCount()
+    headsKV := f.KV().HeadCountKV()
+    vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
 
-    embeddingHeads := llm.KV().EmbeddingHeadCount()
-    embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
-    embeddingHeadsV := llm.KV().EmbeddingHeadCountV()
+    embeddingHeads := f.KV().EmbeddingHeadCount()
+    embeddingHeadsK := f.KV().EmbeddingHeadCountK()
+    embeddingHeadsV := f.KV().EmbeddingHeadCountV()
 
-    layers := llm.Tensors().Layers()
+    layers := f.Tensors().GroupLayers()
 
     bytesPerElement := kvCacheBytesPerElement(kvCacheType)
-    kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+    kv = make([]uint64, f.KV().BlockCount())
+    for i := range kv {
+        kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+    }
 
-    switch llm.KV().Architecture() {
-    case "llama":
+    switch f.KV().Architecture() {
+    case "llama", "llama4":
         fullOffload = max(
             4*batch*(1+4*embedding+context*(1+heads)),
             4*batch*(embedding+vocab),
@@ -390,7 +444,7 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
 
         if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
             // mixtral 8x22b
-            ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
+            ff := uint64(f.KV().Uint("feed_forward_length"))
             partialOffload = max(
                 3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
                 4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@@ -407,16 +461,14 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
     case "mllama":
         var visionTokens, tiles uint64 = 1601, 4
 
-        if crossAttentionLayers, ok := llm.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
-            kv = headsKV *
-                (embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
-                (2* // sizeof(float16)
-                    (llm.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
-                    context +
-                    4* // sizeof(float32)
-                        uint64(crossAttentionLayers.size)* // num cross attention layers
-                        visionTokens*
-                        tiles)
+        crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers")
+        for i := range kv {
+            if slices.Contains(crossAttentionLayers, int32(i)) {
+                kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
+                    4 * // sizeof(float32)
+                    visionTokens *
+                    tiles
+            }
         }
 
         fullOffload = max(
@@ -426,7 +478,7 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
         )
 
         var ropeFreqsCount uint64
-        if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok {
+        if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
             if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
                 ropeFreqsCount = ropeFreqsWeights.parameters()
             }
@@ -440,7 +492,7 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
             // vocab graph
             4*batch*(embedding+vocab)+embedding*vocab*105/128,
         )
-    case "gemma", "gemma2":
+    case "gemma", "gemma2", "gemma3":
         fullOffload = max(
             4*batch*(embedding+vocab),
             4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -452,6 +504,20 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
                 4*embeddingHeadsK*context*8+
                 embedding*embeddingHeadsK*heads*9/16,
         )
+
+        // Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
+        // engine. Gemma3 always uses the Ollama engine.
+        if f.KV().Architecture() == "gemma3" {
+            const gemma3GlobalCacheCount = 6
+            slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
+            for i := range kv {
+                // Every 6th layer is a global layer, which is the full context size that has already been set. The other
+                // layers are the smaller local (sliding) layers.
+                if (i+1)%gemma3GlobalCacheCount != 0 {
+                    kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+                }
+            }
+        }
     case "command-r":
         fullOffload = max(
             4*batch*(embedding+vocab),
@@ -529,22 +595,74 @@ func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partia
     return
 }
 
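GraphSize now reports one cache size per transformer layer instead of a single total, which is what lets the gemma3 branch above shrink only the sliding-window layers. A rough, self-contained check of the per-layer formula with made-up llama-style numbers, assuming 2 bytes per element for an f16 cache (kvCacheBytesPerElement itself is not shown in this diff):

package main

import "fmt"

func main() {
    var (
        contextLen      uint64  = 8192
        blockCount      uint64  = 32  // transformer layers
        headsKV         uint64  = 8   // grouped-query KV heads
        embeddingHeadsK uint64  = 128 // per-head K dimension
        embeddingHeadsV uint64  = 128 // per-head V dimension
        bytesPerElement float64 = 2   // assumed f16 cache
    )

    kv := make([]uint64, blockCount)
    var total uint64
    for i := range kv {
        kv[i] = uint64(float64(contextLen*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
        total += kv[i]
    }
    fmt.Printf("per layer: %d bytes, total: %d bytes (%.2f GiB)\n", kv[0], total, float64(total)/(1<<30))
}
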
+func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
+    if llm.KV().Uint("vision.block_count") == 0 {
+        return
+    }
+
+    for name, layer := range llm.Tensors().GroupLayers() {
+        if name == "v" || strings.HasPrefix(name, "v.") {
+            for _, tensor := range layer {
+                weights += tensor.Size()
+            }
+        }
+    }
+
+    imageSize := uint64(llm.KV().Uint("vision.image_size"))
+    patchSize := uint64(llm.KV().Uint("vision.patch_size"))
+    if patchSize == 0 {
+        slog.Warn("unknown patch size for vision model")
+        return
+    }
+
+    numChannels := uint64(llm.KV().Uint("vision.num_channels"))
+
+    numPatches := (imageSize / patchSize) * (imageSize / patchSize)
+    if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
+        numPatches++
+    }
+
+    headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
+    embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
+
+    switch llm.KV().Architecture() {
+    case "mllama":
+        numPaddedPatches := numPatches + 8 - (numPatches%8)%8
+
+        maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
+
+        graphSize = 4 * (8 +
+            imageSize*imageSize*numChannels*maxNumTiles +
+            embeddingLength*numPatches*maxNumTiles +
+            9*embeddingLength*numPaddedPatches*maxNumTiles +
+            numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
+    case "gemma3", "mistral3":
+        graphSize = 4 * (imageSize*imageSize*numChannels +
+            embeddingLength*patchSize +
+            numPatches*numPatches*headCount)
+    case "llama4":
+        // vision graph is computed independently in the same schedule
+        // and is negligible compared to the worst case text graph
+    }
+
+    return weights, graphSize
+}
+
 // SupportsKVCacheType checks if the requested cache type is supported
-func (ggml GGML) SupportsKVCacheType(cacheType string) bool {
-    validKVCacheTypes := []string{"f16", "q8_0", "q4_0"}
-    return slices.Contains(validKVCacheTypes, cacheType)
+func (f GGML) SupportsKVCacheType(cacheType string) bool {
+    return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
 }
 
 // SupportsFlashAttention checks if the model supports flash attention
-func (ggml GGML) SupportsFlashAttention() bool {
-    _, isEmbedding := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]
+func (f GGML) SupportsFlashAttention() bool {
+    _, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
     if isEmbedding {
         return false
     }
 
     // Check head counts match and are non-zero
-    headCountK := ggml.KV().EmbeddingHeadCountK()
-    headCountV := ggml.KV().EmbeddingHeadCountV()
+    headCountK := f.KV().EmbeddingHeadCountK()
+    headCountV := f.KV().EmbeddingHeadCountV()
     return headCountK != 0 && headCountV != 0 && headCountK == headCountV
 }
 
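A small illustrative policy helper built on the two predicates above; the function name and the f16 fallback are this note's own, only SupportsFlashAttention and SupportsKVCacheType come from the change:

package sample

import "github.com/ollama/ollama/fs/ggml"

// pickKVCacheType returns the requested cache type when the model can use it
// and falls back to plain f16 otherwise.
func pickKVCacheType(f ggml.GGML, requested string) string {
    if f.SupportsFlashAttention() && f.SupportsKVCacheType(requested) {
        return requested
    }
    return "f16"
}
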
fs/ggml/ggml_test.go  (new file, 271 lines)
@@ -0,0 +1,271 @@
package ggml

import (
    "maps"
    "math"
    "slices"
    "strconv"
    "strings"
    "testing"

    "github.com/google/go-cmp/cmp"
)

func TestTensorLayers(t *testing.T) {
    tensors := make(map[string]*Tensor)
    for _, name := range []string{
        "token_embd.weight",
        "blk.0.attn_k.weight",
        "blk.0.attn_output.weight",
        "blk.0.attn_q.weight",
        "blk.0.attn_v.weight",
        "blk.0.attn_norm.weight",
        "blk.0.ffn_down.weight",
        "blk.0.ffn_gate.weight",
        "blk.0.ffn_up.weight",
        "blk.0.ffn_norm.weight",
        "output_norm.weight",
        "mm.0.bias",
        "mm.0.weight",
        "v.blk.0.attn_k.weight",
        "v.blk.0.attn_output.weight",
        "v.blk.0.attn_q.weight",
        "v.blk.0.attn_v.weight",
        "v.blk.0.attn_norm.weight",
        "v.blk.0.ffn_down.weight",
        "v.blk.0.ffn_gate.weight",
        "v.blk.0.ffn_up.weight",
        "v.blk.0.ffn_norm.weight",
        "v.patch_embd.weight",
        "v.position_embd.gate",
        "v.position_embd.weight",
    } {
        tensors[name] = &Tensor{Name: name}
    }

    cases := []struct {
        name  string
        items []*Tensor
        want  map[string]Layer
    }{
        {
            name: "text",
            items: slices.Collect(func(yield func(*Tensor) bool) {
                for k, v := range tensors {
                    if !strings.HasPrefix(k, "mm.") && !strings.HasPrefix(k, "v.") {
                        if !yield(v) {
                            return
                        }
                    }
                }
            }),
            want: map[string]Layer{
                "blk.0": {
                    "attn_k.weight":      tensors["blk.0.attn_k.weight"],
                    "attn_q.weight":      tensors["blk.0.attn_q.weight"],
                    "attn_v.weight":      tensors["blk.0.attn_v.weight"],
                    "attn_output.weight": tensors["blk.0.attn_output.weight"],
                    "attn_norm.weight":   tensors["blk.0.attn_norm.weight"],
                    "ffn_down.weight":    tensors["blk.0.ffn_down.weight"],
                    "ffn_gate.weight":    tensors["blk.0.ffn_gate.weight"],
                    "ffn_up.weight":      tensors["blk.0.ffn_up.weight"],
                    "ffn_norm.weight":    tensors["blk.0.ffn_norm.weight"],
                },
                "token_embd":  {"weight": tensors["token_embd.weight"]},
                "output_norm": {"weight": tensors["output_norm.weight"]},
            },
        },
        {
            name: "vision",
            items: slices.Collect(func(yield func(*Tensor) bool) {
                for k, v := range tensors {
                    if strings.HasPrefix(k, "mm.") || strings.HasPrefix(k, "v.") {
                        if !yield(v) {
                            return
                        }
                    }
                }
            }),
            want: map[string]Layer{
                "mm.0": {
                    "bias":   tensors["mm.0.bias"],
                    "weight": tensors["mm.0.weight"],
                },
                "v.blk.0": {
                    "attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
                    "attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
                    "attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
                    "attn_output.weight": tensors["v.blk.0.attn_output.weight"],
                    "attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
                    "ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
                    "ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
                    "ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
                    "ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
                },
                "v": {
                    "patch_embd.weight":    tensors["v.patch_embd.weight"],
                    "position_embd.gate":   tensors["v.position_embd.gate"],
                    "position_embd.weight": tensors["v.position_embd.weight"],
                },
            },
        },
        {
            name:  "vision and text",
            items: slices.Collect(maps.Values(tensors)),
            want: map[string]Layer{
                "blk.0": {
                    "attn_k.weight":      tensors["blk.0.attn_k.weight"],
                    "attn_q.weight":      tensors["blk.0.attn_q.weight"],
                    "attn_v.weight":      tensors["blk.0.attn_v.weight"],
                    "attn_output.weight": tensors["blk.0.attn_output.weight"],
                    "attn_norm.weight":   tensors["blk.0.attn_norm.weight"],
                    "ffn_down.weight":    tensors["blk.0.ffn_down.weight"],
                    "ffn_gate.weight":    tensors["blk.0.ffn_gate.weight"],
                    "ffn_up.weight":      tensors["blk.0.ffn_up.weight"],
                    "ffn_norm.weight":    tensors["blk.0.ffn_norm.weight"],
                },
                "token_embd":  {"weight": tensors["token_embd.weight"]},
                "output_norm": {"weight": tensors["output_norm.weight"]},
                "mm.0": {
                    "bias":   tensors["mm.0.bias"],
                    "weight": tensors["mm.0.weight"],
                },
                "v.blk.0": {
                    "attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
                    "attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
                    "attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
                    "attn_output.weight": tensors["v.blk.0.attn_output.weight"],
                    "attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
                    "ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
                    "ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
                    "ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
                    "ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
                },
                "v": {
                    "patch_embd.weight":    tensors["v.patch_embd.weight"],
                    "position_embd.gate":   tensors["v.position_embd.gate"],
                    "position_embd.weight": tensors["v.position_embd.weight"],
                },
            },
        },
    }

    for _, tt := range cases {
        t.Run(tt.name, func(t *testing.T) {
            got := Tensors{items: tt.items}.GroupLayers()
            if diff := cmp.Diff(got, tt.want); diff != "" {
                t.Errorf("unexpected layers (-got +want):\n%s", diff)
            }
        })
    }
}

// ref: https://github.com/ggml-org/llama.cpp/blob/a82c9e7c23ef6db48cebfa194dc9cebbc4ac3552/ggml/src/ggml.c#L572
func TestTensorTypes(t *testing.T) {
    cases := []struct {
        kind      uint32
        blockSize uint64
        typeSize  uint64
    }{
        {0, 1, 4},
        {1, 1, 2},
        {2, 32, 18},
        {3, 32, 20},
        {6, 32, 22},
        {7, 32, 24},
        {8, 32, 34},
        {9, 32, 36},
        {10, 256, 84},
        {11, 256, 110},
        {12, 256, 144},
        {13, 256, 176},
        {14, 256, 210},
        {15, 256, 292},
        {16, 256, 66},
        {17, 256, 74},
        {18, 256, 98},
        {19, 256, 50},
        {20, 32, 18},
        {21, 256, 110},
        {22, 256, 82},
        {23, 256, 136},
        {24, 1, 1},
        {25, 1, 2},
        {26, 1, 4},
        {27, 1, 8},
        {28, 1, 8},
        {29, 256, 56},
        {30, 1, 2},
    }

    for _, tt := range cases {
        t.Run(strconv.Itoa(int(tt.kind)), func(t *testing.T) {
            tensor := Tensor{Kind: tt.kind}
            if tensor.blockSize() != tt.blockSize {
                t.Errorf("unexpected block size: got=%d want=%d", tensor.blockSize(), tt.blockSize)
            }

            if tensor.typeSize() != tt.typeSize {
                t.Errorf("unexpected type size: got=%d want=%d", tensor.typeSize(), tt.typeSize)
            }
        })
    }
}

func TestKeyValue(t *testing.T) {
    kv := KV{
        "general.architecture": "test",
        "test.strings":         &array[string]{size: 3, values: []string{"a", "b", "c"}},
        "test.float32s":        &array[float32]{size: 3, values: []float32{1.0, 2.0, 3.0}},
        "test.int32s":          &array[int32]{size: 3, values: []int32{1, 2, 3}},
        "test.uint32s":         &array[uint32]{size: 3, values: []uint32{1, 2, 3}},
    }

    if diff := cmp.Diff(kv.Strings("strings"), []string{"a", "b", "c"}); diff != "" {
        t.Errorf("unexpected strings (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Strings("nonexistent.strings"), []string(nil)); diff != "" {
        t.Errorf("unexpected strings (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Strings("default.strings", []string{"ollama"}), []string{"ollama"}); diff != "" {
        t.Errorf("unexpected strings (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Floats("float32s"), []float32{1.0, 2.0, 3.0}); diff != "" {
        t.Errorf("unexpected float32s (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Floats("nonexistent.float32s"), []float32(nil)); diff != "" {
        t.Errorf("unexpected float32s (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Floats("default.float32s", []float32{math.MaxFloat32}), []float32{math.MaxFloat32}); diff != "" {
        t.Errorf("unexpected float32s (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Ints("int32s"), []int32{1, 2, 3}); diff != "" {
        t.Errorf("unexpected int8s (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Ints("nonexistent.int32s"), []int32(nil)); diff != "" {
        t.Errorf("unexpected int8s (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Ints("default.int32s", []int32{math.MaxInt32}), []int32{math.MaxInt32}); diff != "" {
        t.Errorf("unexpected int8s (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Uints("uint32s"), []uint32{1, 2, 3}); diff != "" {
        t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Uints("nonexistent.uint32s"), []uint32(nil)); diff != "" {
        t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
    }

    if diff := cmp.Diff(kv.Uints("default.uint32s", []uint32{math.MaxUint32}), []uint32{math.MaxUint32}); diff != "" {
        t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
    }
}
@@ -1,4 +1,4 @@
-package llm
+package ggml
 
 import (
     "bytes"
@@ -8,10 +8,9 @@ import (
     "fmt"
     "io"
     "log/slog"
+    "maps"
     "slices"
     "strings"
 
-    "golang.org/x/exp/maps"
 )
 
 type containerGGUF struct {
@@ -37,10 +36,6 @@ type containerGGUF struct {
     maxArraySize int
 }
 
-func (c *containerGGUF) canCollectArray(size int) bool {
-    return c.maxArraySize < 0 || size <= c.maxArraySize
-}
-
 func (c *containerGGUF) Name() string {
     return "gguf"
 }
@@ -110,9 +105,9 @@ func (llm *gguf) KV() KV {
     return llm.kv
 }
 
-func (llm *gguf) Tensors() *Tensors {
-    return &Tensors{
-        Items:  llm.tensors,
+func (llm *gguf) Tensors() Tensors {
+    return Tensors{
+        items:  llm.tensors,
         Offset: llm.tensorOffset,
     }
 }
@@ -236,10 +231,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
     // patch KV with parameter count
     llm.kv["general.parameter_count"] = llm.parameters
 
-    alignment, ok := llm.kv["general.alignment"].(uint32)
-    if !ok {
-        alignment = 32
-    }
+    alignment := llm.kv.Uint("general.alignment", 32)
 
     offset, err := rs.Seek(0, io.SeekCurrent)
     if err != nil {
@@ -299,6 +291,23 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
     return b.String(), nil
 }
 
+func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
+    for i := range a.size {
+        if a.values != nil {
+            e, err := readGGUFV1String(llm, r)
+            if err != nil {
+                return nil, err
+            }
+
+            a.values[i] = e
+        } else {
+            discardGGUFString(llm, r)
+        }
+    }
+
+    return a, nil
+}
+
 func discardGGUFString(llm *gguf, r io.Reader) error {
     buf := llm.scratch[:8]
     _, err := io.ReadFull(r, buf)
@@ -356,78 +365,44 @@ func writeGGUFString(w io.Writer, s string) error {
     return err
 }
 
-type array struct {
-    size   int
-    values []any
-}
-
-func (a *array) MarshalJSON() ([]byte, error) {
-    return json.Marshal(a.values)
-}
-
-func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
-    t, err := readGGUF[uint32](llm, r)
-    if err != nil {
-        return nil, err
-    }
-
-    n, err := readGGUF[uint32](llm, r)
-    if err != nil {
-        return nil, err
-    }
-
-    a := &array{size: int(n)}
-    if llm.canCollectArray(int(n)) {
-        a.values = make([]any, 0, int(n))
-    }
-
-    for i := range n {
-        var e any
-        switch t {
-        case ggufTypeUint8:
-            e, err = readGGUF[uint8](llm, r)
-        case ggufTypeInt8:
-            e, err = readGGUF[int8](llm, r)
-        case ggufTypeUint16:
-            e, err = readGGUF[uint16](llm, r)
-        case ggufTypeInt16:
-            e, err = readGGUF[int16](llm, r)
-        case ggufTypeUint32:
-            e, err = readGGUF[uint32](llm, r)
-        case ggufTypeInt32:
-            e, err = readGGUF[int32](llm, r)
-        case ggufTypeUint64:
-            e, err = readGGUF[uint64](llm, r)
-        case ggufTypeInt64:
-            e, err = readGGUF[int64](llm, r)
-        case ggufTypeFloat32:
-            e, err = readGGUF[float32](llm, r)
-        case ggufTypeFloat64:
-            e, err = readGGUF[float64](llm, r)
-        case ggufTypeBool:
-            e, err = readGGUF[bool](llm, r)
-        case ggufTypeString:
-            e, err = readGGUFV1String(llm, r)
-        default:
-            return nil, fmt.Errorf("invalid array type: %d", t)
-        }
-        if err != nil {
-            return nil, err
-        }
-
+func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
+    for i := range a.size {
         if a.values != nil {
+            e, err := readGGUFString(llm, r)
+            if err != nil {
+                return nil, err
+            }
+
             a.values[i] = e
+        } else {
+            discardGGUFString(llm, r)
         }
     }
 
     return a, nil
 }
 
-func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
-    if llm.Version == 1 {
-        return readGGUFV1Array(llm, r)
-    }
-
+type array[T any] struct {
+    // size is the actual size of the array
+    size int
+
+    // values is the array of values. this is nil if the array is larger than configured maxSize
+    values []T
+}
+
+func (a *array[T]) MarshalJSON() ([]byte, error) {
+    return json.Marshal(a.values)
+}
+
+func newArray[T any](size, maxSize int) *array[T] {
+    a := array[T]{size: size}
+    if maxSize < 0 || size <= maxSize {
+        a.values = make([]T, size)
+    }
+    return &a
+}
+
+func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
     t, err := readGGUF[uint32](llm, r)
     if err != nil {
         return nil, err
@@ -438,45 +413,55 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
     return nil, err
     }
 
-    a := &array{size: int(n)}
-    if llm.canCollectArray(int(n)) {
-        a.values = make([]any, int(n))
-    }
-
-    for i := range n {
-        var e any
     switch t {
     case ggufTypeUint8:
-        e, err = readGGUF[uint8](llm, r)
+        a := newArray[uint8](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeInt8:
-        e, err = readGGUF[int8](llm, r)
+        a := newArray[int8](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeUint16:
-        e, err = readGGUF[uint16](llm, r)
+        a := newArray[uint16](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
    case ggufTypeInt16:
-        e, err = readGGUF[int16](llm, r)
+        a := newArray[int16](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeUint32:
-        e, err = readGGUF[uint32](llm, r)
+        a := newArray[uint32](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeInt32:
-        e, err = readGGUF[int32](llm, r)
+        a := newArray[int32](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeUint64:
-        e, err = readGGUF[uint64](llm, r)
+        a := newArray[uint64](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeInt64:
-        e, err = readGGUF[int64](llm, r)
+        a := newArray[int64](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeFloat32:
-        e, err = readGGUF[float32](llm, r)
+        a := newArray[float32](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeFloat64:
-        e, err = readGGUF[float64](llm, r)
+        a := newArray[float64](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeBool:
-        e, err = readGGUF[bool](llm, r)
+        a := newArray[bool](int(n), llm.maxArraySize)
+        return readGGUFArrayData(llm, r, a)
     case ggufTypeString:
-        if a.values != nil {
-            e, err = readGGUFString(llm, r)
-        } else {
-            err = discardGGUFString(llm, r)
-        }
+        a := newArray[string](int(n), llm.maxArraySize)
+        if llm.Version == 1 {
+            return readGGUFV1StringsData(llm, r, a)
+        }
+
+        return readGGUFStringsData(llm, r, a)
     default:
         return nil, fmt.Errorf("invalid array type: %d", t)
     }
+}
+
+func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) {
+    for i := range a.size {
+        e, err := readGGUF[T](llm, r)
         if err != nil {
             return nil, err
         }
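The typed array replaces the old []any-backed version, and the old canCollectArray check now lives in newArray. A standalone restatement of that size policy (the type and constructor are copied from the hunk above; main is only a demonstration):

package main

import "fmt"

type array[T any] struct {
    size   int
    values []T // nil when the array is larger than the configured maximum
}

func newArray[T any](size, maxSize int) *array[T] {
    a := array[T]{size: size}
    if maxSize < 0 || size <= maxSize {
        a.values = make([]T, size)
    }
    return &a
}

func main() {
    small := newArray[uint32](8, 1024)
    large := newArray[string](150_000, 1024) // e.g. a big tokenizer vocabulary
    fmt.Println(len(small.values), small.size) // 8 8
    fmt.Println(len(large.values), large.size) // 0 150000
}
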
@@ -507,6 +492,8 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
 }
 
 func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
+    alignment := kv.Uint("general.alignment", 32)
+
     if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
         return err
     }
@@ -523,7 +510,7 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
         return err
     }
 
-    keys := maps.Keys(kv)
+    keys := slices.Collect(maps.Keys(kv))
     slices.Sort(keys)
 
     for _, key := range keys {
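The keys line above swaps golang.org/x/exp/maps for the standard library: with Go 1.23+ iterators, maps.Keys yields an iter.Seq that slices.Collect turns back into a slice. A minimal demonstration with a hypothetical map:

package main

import (
    "fmt"
    "maps"
    "slices"
)

func main() {
    kv := map[string]any{
        "general.architecture": "llama",
        "llama.block_count":    uint32(32),
    }
    keys := slices.Collect(maps.Keys(kv))
    slices.Sort(keys)
    fmt.Println(keys) // [general.architecture llama.block_count]
}
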
@@ -544,16 +531,15 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
 
     var s uint64
     for _, t := range ts {
-        t.Offset = s
+        t.Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
         if err := ggufWriteTensorInfo(ws, t); err != nil {
             return err
         }
         s += t.Size()
     }
 
-    var alignment int64 = 32
     for _, t := range ts {
-        if err := ggufWriteTensor(ws, t, alignment); err != nil {
+        if err := ggufWriteTensor(ws, t, int64(alignment)); err != nil {
             return err
         }
     }
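Tensor offsets are now padded up to general.alignment before their info records are written. The ggufPadding helper itself is outside this hunk, so the arithmetic below is only the conventional round-up-to-the-next-multiple computation, not necessarily the exact implementation:

package main

import "fmt"

// padding returns how many bytes are needed to advance offset to the next
// multiple of align.
func padding(offset, align int64) int64 {
    return (align - offset%align) % align
}

func main() {
    fmt.Println(padding(0, 32))   // 0
    fmt.Println(padding(100, 32)) // 28, so the next tensor starts at 128
    fmt.Println(padding(128, 32)) // 0
}
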
@@ -630,8 +616,8 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
         return err
     }
 
-    for i := range len(t.Shape) {
-        if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
+    for _, n := range t.Shape {
+        if err := binary.Write(ws, binary.LittleEndian, n); err != nil {
             return err
         }
     }
@@ -1,4 +1,4 @@
-package llm
+package ggml
 
 import "fmt"
 
@@ -98,10 +98,10 @@ func ParseFileType(s string) (fileType, error) {
         return fileTypeIQ3_M, nil
     case "IQ2_S":
         return fileTypeIQ2_S, nil
-    case "IQ4_XS":
-        return fileTypeIQ4_XS, nil
     case "IQ2_M":
         return fileTypeIQ2_M, nil
+    case "IQ4_XS":
+        return fileTypeIQ4_XS, nil
     case "IQ1_M":
         return fileTypeIQ1_M, nil
     case "BF16":
go.mod  (18 changes)
@@ -1,6 +1,6 @@
 module github.com/ollama/ollama
 
-go 1.23.4
+go 1.24.0
 
 require (
     github.com/containerd/console v1.0.3
@@ -11,18 +11,20 @@ require (
     github.com/spf13/cobra v1.7.0
     github.com/stretchr/testify v1.9.0
     github.com/x448/float16 v0.8.4
-    golang.org/x/sync v0.10.0
+    golang.org/x/sync v0.11.0
 )
 
 require (
     github.com/agnivade/levenshtein v1.1.1
     github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
+    github.com/dlclark/regexp2 v1.11.4
     github.com/emirpasic/gods/v2 v2.0.0-alpha
     github.com/google/go-cmp v0.6.0
     github.com/mattn/go-runewidth v0.0.14
     github.com/nlpodyssey/gopickle v0.3.0
     github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
     golang.org/x/image v0.22.0
+    golang.org/x/tools v0.30.0
 )
 
 require (
@@ -68,12 +70,12 @@ require (
     github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
     github.com/ugorji/go/codec v1.2.12 // indirect
     golang.org/x/arch v0.8.0 // indirect
-    golang.org/x/crypto v0.31.0
-    golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
-    golang.org/x/net v0.25.0 // indirect
-    golang.org/x/sys v0.28.0
-    golang.org/x/term v0.27.0
-    golang.org/x/text v0.21.0
+    golang.org/x/crypto v0.33.0
+    golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
+    golang.org/x/net v0.35.0 // indirect
+    golang.org/x/sys v0.30.0
+    golang.org/x/term v0.29.0
+    golang.org/x/text v0.22.0
     google.golang.org/protobuf v1.34.1
     gopkg.in/yaml.v3 v3.0.1 // indirect
 )

go.sum  (32 changes)
@@ -42,6 +42,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
 github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
+github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
+github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU=
 github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -212,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
-golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
+golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
+golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
 golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
-golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
-golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
+golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
+golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
 golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
 golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -255,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
 golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
-golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
+golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
+golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -266,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
-golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
+golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -283,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
-golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
+golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
-golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
+golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
+golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
-golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
+golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
+golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
 golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -307,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
 golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
+golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|||||||
412
integration/api_test.go
Normal file
412
integration/api_test.go
Normal file
@@ -0,0 +1,412 @@
//go:build integration

package integration

import (
    "bytes"
    "context"
    "fmt"
    "math/rand"
    "strings"
    "testing"
    "time"

    "github.com/ollama/ollama/api"
)

func TestAPIGenerate(t *testing.T) {
    initialTimeout := 60 * time.Second
    streamTimeout := 30 * time.Second
    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
    defer cancel()
    // Set up the test data
    req := api.GenerateRequest{
        Model:  smol,
        Prompt: "why is the sky blue? be brief",
        Options: map[string]interface{}{
            "temperature": 0,
            "seed":        123,
        },
    }
    anyResp := []string{"rayleigh", "scattering"}

    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()
    if err := PullIfMissing(ctx, client, req.Model); err != nil {
        t.Fatalf("pull failed %s", err)
    }

    tests := []struct {
        name   string
        stream bool
    }{
        {
            name:   "stream",
            stream: true,
        },
        {
            name:   "no_stream",
            stream: false,
        },
    }

    for _, test := range tests {
        t.Run(test.name, func(t *testing.T) {
            stallTimer := time.NewTimer(initialTimeout)
            var buf bytes.Buffer
            fn := func(response api.GenerateResponse) error {
                // Fields that must always be present
                if response.Model == "" {
                    t.Errorf("response missing model: %#v", response)
                }
                if response.Done {
                    // Required fields for final updates:
                    if response.DoneReason == "" && *req.Stream {
                        // TODO - is the lack of done reason on non-stream a bug?
                        t.Errorf("final response missing done_reason: %#v", response)
                    }
                    if response.Metrics.TotalDuration == 0 {
                        t.Errorf("final response missing total_duration: %#v", response)
                    }
                    if response.Metrics.LoadDuration == 0 {
                        t.Errorf("final response missing load_duration: %#v", response)
                    }
                    if response.Metrics.PromptEvalDuration == 0 {
                        t.Errorf("final response missing prompt_eval_duration: %#v", response)
                    }
                    if response.Metrics.EvalCount == 0 {
                        t.Errorf("final response missing eval_count: %#v", response)
                    }
                    if response.Metrics.EvalDuration == 0 {
                        t.Errorf("final response missing eval_duration: %#v", response)
                    }
                    if len(response.Context) == 0 {
                        t.Errorf("final response missing context: %#v", response)
                    }

                    // Note: caching can result in no prompt eval count, so this can't be verified reliably
                    // if response.Metrics.PromptEvalCount == 0 {
                    //     t.Errorf("final response missing prompt_eval_count: %#v", response)
                    // }

                } // else incremental response, nothing to check right now...
                buf.Write([]byte(response.Response))
                if !stallTimer.Reset(streamTimeout) {
                    return fmt.Errorf("stall was detected while streaming response, aborting")
                }
                return nil
            }

            done := make(chan int)
            var genErr error
            go func() {
                req.Stream = &test.stream
                req.Options["seed"] = rand.Int() // bust cache for prompt eval results
                genErr = client.Generate(ctx, &req, fn)
                done <- 0
            }()

            select {
            case <-stallTimer.C:
                if buf.Len() == 0 {
                    t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
                } else {
                    t.Errorf("generate stalled. Response so far:%s", buf.String())
                }
            case <-done:
                if genErr != nil {
                    t.Fatalf("failed with %s request prompt %s ", req.Model, req.Prompt)
                }
                // Verify the response contains the expected data
                response := buf.String()
                atLeastOne := false
                for _, resp := range anyResp {
                    if strings.Contains(strings.ToLower(response), resp) {
                        atLeastOne = true
                        break
                    }
                }
                if !atLeastOne {
                    t.Errorf("none of %v found in %s", anyResp, response)
                }
            case <-ctx.Done():
                t.Error("outer test context done while waiting for generate")
            }
        })
    }

    // Validate PS while we're at it...
    resp, err := client.ListRunning(ctx)
    if err != nil {
        t.Fatalf("list models API error: %s", err)
    }
    if resp == nil || len(resp.Models) == 0 {
        t.Fatalf("list models API returned empty list while model should still be loaded")
    }
    // Find the model we just loaded and verify some attributes
    found := false
    for _, model := range resp.Models {
        if strings.Contains(model.Name, req.Model) {
            found = true
            if model.Model == "" {
                t.Errorf("model field omitted: %#v", model)
            }
            if model.Size == 0 {
                t.Errorf("size omitted: %#v", model)
            }
            if model.Digest == "" {
                t.Errorf("digest omitted: %#v", model)
            }
            verifyModelDetails(t, model.Details)
            var nilTime time.Time
            if model.ExpiresAt == nilTime {
                t.Errorf("expires_at omitted: %#v", model)
            }
            // SizeVRAM could be zero.
        }
    }
    if !found {
        t.Errorf("unable to locate running model: %#v", resp)
    }
}

func TestAPIChat(t *testing.T) {
    initialTimeout := 60 * time.Second
    streamTimeout := 30 * time.Second
    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
    defer cancel()
    // Set up the test data
    req := api.ChatRequest{
        Model: smol,
        Messages: []api.Message{
            {
                Role:    "user",
                Content: "why is the sky blue? be brief",
            },
        },
        Options: map[string]interface{}{
            "temperature": 0,
            "seed":        123,
        },
    }
    anyResp := []string{"rayleigh", "scattering"}

    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()
    if err := PullIfMissing(ctx, client, req.Model); err != nil {
        t.Fatalf("pull failed %s", err)
    }

    tests := []struct {
        name   string
        stream bool
    }{
        {
            name:   "stream",
            stream: true,
        },
        {
            name:   "no_stream",
            stream: false,
        },
    }

    for _, test := range tests {
        t.Run(test.name, func(t *testing.T) {
            stallTimer := time.NewTimer(initialTimeout)
            var buf bytes.Buffer
            fn := func(response api.ChatResponse) error {
                // Fields that must always be present
                if response.Model == "" {
                    t.Errorf("response missing model: %#v", response)
                }
                if response.Done {
                    // Required fields for final updates:
                    var nilTime time.Time
                    if response.CreatedAt == nilTime {
                        t.Errorf("final response missing created_at: %#v", response)
                    }
                    if response.DoneReason == "" {
                        t.Errorf("final response missing done_reason: %#v", response)
                    }
                    if response.Metrics.TotalDuration == 0 {
                        t.Errorf("final response missing total_duration: %#v", response)
                    }
                    if response.Metrics.LoadDuration == 0 {
                        t.Errorf("final response missing load_duration: %#v", response)
                    }
                    if response.Metrics.PromptEvalDuration == 0 {
                        t.Errorf("final response missing prompt_eval_duration: %#v", response)
                    }
                    if response.Metrics.EvalCount == 0 {
                        t.Errorf("final response missing eval_count: %#v", response)
                    }
                    if response.Metrics.EvalDuration == 0 {
                        t.Errorf("final response missing eval_duration: %#v", response)
                    }

                    if response.Metrics.PromptEvalCount == 0 {
                        t.Errorf("final response missing prompt_eval_count: %#v", response)
                    }
                } // else incremental response, nothing to check right now...
                buf.Write([]byte(response.Message.Content))
                if !stallTimer.Reset(streamTimeout) {
                    return fmt.Errorf("stall was detected while streaming response, aborting")
                }
                return nil
            }

            done := make(chan int)
            var genErr error
            go func() {
                req.Stream = &test.stream
                req.Options["seed"] = rand.Int() // bust cache for prompt eval results
                genErr = client.Chat(ctx, &req, fn)
                done <- 0
            }()

            select {
            case <-stallTimer.C:
                if buf.Len() == 0 {
                    t.Errorf("chat never started. Timed out after :%s", initialTimeout.String())
                } else {
                    t.Errorf("chat stalled. Response so far:%s", buf.String())
                }
            case <-done:
                if genErr != nil {
                    t.Fatalf("failed with %s request prompt %v", req.Model, req.Messages)
                }
                // Verify the response contains the expected data
                response := buf.String()
                atLeastOne := false
                for _, resp := range anyResp {
                    if strings.Contains(strings.ToLower(response), resp) {
                        atLeastOne = true
                        break
                    }
                }
                if !atLeastOne {
                    t.Errorf("none of %v found in %s", anyResp, response)
                }
            case <-ctx.Done():
                t.Error("outer test context done while waiting for chat")
            }
        })
    }
}

func TestAPIListModels(t *testing.T) {
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    defer cancel()
    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()

    // Make sure we have at least one model so an empty list can be considered a failure
    if err := PullIfMissing(ctx, client, smol); err != nil {
        t.Fatalf("pull failed %s", err)
    }

    resp, err := client.List(ctx)
    if err != nil {
        t.Fatalf("unable to list models: %s", err)
    }
    if len(resp.Models) == 0 {
        t.Fatalf("list should not be empty")
    }
    model := resp.Models[0]
    if model.Name == "" {
        t.Errorf("first model name empty: %#v", model)
    }
    var nilTime time.Time
    if model.ModifiedAt == nilTime {
        t.Errorf("first model modified_at empty: %#v", model)
    }
    if model.Size == 0 {
        t.Errorf("first model size empty: %#v", model)
    }
    if model.Digest == "" {
        t.Errorf("first model digest empty: %#v", model)
    }
    verifyModelDetails(t, model.Details)
}

func verifyModelDetails(t *testing.T, details api.ModelDetails) {
    if details.Format == "" {
        t.Errorf("first model details.format empty: %#v", details)
    }
    if details.Family == "" {
        t.Errorf("first model details.family empty: %#v", details)
    }
    if details.ParameterSize == "" {
        t.Errorf("first model details.parameter_size empty: %#v", details)
    }
    if details.QuantizationLevel == "" {
        t.Errorf("first model details.quantization_level empty: %#v", details)
    }
}

func TestAPIShowModel(t *testing.T) {
    modelName := "llama3.2"
    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
    defer cancel()
    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()

    if err := PullIfMissing(ctx, client, modelName); err != nil {
        t.Fatalf("pull failed %s", err)
    }
    resp, err := client.Show(ctx, &api.ShowRequest{Name: modelName})
    if err != nil {
        t.Fatalf("unable to show model: %s", err)
    }
    if resp.License == "" {
        t.Errorf("%s missing license: %#v", modelName, resp)
    }
    if resp.Modelfile == "" {
        t.Errorf("%s missing modelfile: %#v", modelName, resp)
    }
    if resp.Parameters == "" {
        t.Errorf("%s missing parameters: %#v", modelName, resp)
    }
    if resp.Template == "" {
        t.Errorf("%s missing template: %#v", modelName, resp)
    }
    // llama3 omits system
    verifyModelDetails(t, resp.Details)
    // llama3 omits messages
    if len(resp.ModelInfo) == 0 {
        t.Errorf("%s missing model_info: %#v", modelName, resp)
    }
    // llama3 omits projectors
    var nilTime time.Time
    if resp.ModifiedAt == nilTime {
        t.Errorf("%s missing modified_at: %#v", modelName, resp)
    }
}

func TestAPIEmbeddings(t *testing.T) {
    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
    defer cancel()
    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()
    req := api.EmbeddingRequest{
        Model:  "orca-mini",
        Prompt: "why is the sky blue?",
        Options: map[string]interface{}{
            "temperature": 0,
            "seed":        123,
        },
    }

    if err := PullIfMissing(ctx, client, req.Model); err != nil {
        t.Fatalf("pull failed %s", err)
    }

    resp, err := client.Embeddings(ctx, &req)
    if err != nil {
        t.Fatalf("embeddings call failed %s", err)
    }
    if len(resp.Embedding) == 0 {
        t.Errorf("zero length embedding response")
    }
}
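Both TestAPIGenerate and TestAPIChat above lean on the same stall-detection idea: a time.Timer guards the gap between streaming callbacks, and Timer.Reset reports false once the timer has already fired, which the callback turns into an error. The following standalone sketch (not part of this change; all names and values are illustrative) distills that pattern:

package main

import (
    "fmt"
    "time"
)

func main() {
    stallTimeout := 50 * time.Millisecond
    stallTimer := time.NewTimer(stallTimeout)
    chunks := make(chan string)

    go func() {
        for _, c := range []string{"hello ", "world"} {
            time.Sleep(10 * time.Millisecond) // simulated streaming delay; raise above stallTimeout to see a stall
            chunks <- c
        }
        close(chunks)
    }()

    for c := range chunks {
        // Reset reports false if the timer already fired, i.e. the gap between chunks exceeded stallTimeout.
        if !stallTimer.Reset(stallTimeout) {
            fmt.Println("stall detected while streaming")
            return
        }
        fmt.Print(c)
    }
    fmt.Println()
}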
@@ -14,15 +14,15 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
-func TestOrcaMiniBlueSky(t *testing.T) {
+func TestBlueSky(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
 	// Set up the test data
 	req := api.GenerateRequest{
-		Model:  "orca-mini",
+		Model:  smol,
 		Prompt: "why is the sky blue?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 		},
@@ -31,6 +31,7 @@ func TestOrcaMiniBlueSky(t *testing.T) {
 }
 
 func TestUnicode(t *testing.T) {
+	skipUnderMinVRAM(t, 6)
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
 	defer cancel()
 	// Set up the test data
@@ -39,7 +40,7 @@ func TestUnicode(t *testing.T) {
 		Model:  "deepseek-coder-v2:16b-lite-instruct-q2_K",
 		Prompt: "天空为什么是蓝色的?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 			// Workaround deepseek context shifting bug
@@ -61,7 +62,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
 		Model:  "gemma2:2b",
 		Prompt: "Output some smily face emoji",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 		},
@@ -93,10 +94,10 @@ func TestUnicodeModelDir(t *testing.T) {
 	defer cancel()
 
 	req := api.GenerateRequest{
-		Model:  "orca-mini",
+		Model:  smol,
 		Prompt: "why is the sky blue?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 		},
@@ -21,11 +21,11 @@ func TestMultiModelConcurrency(t *testing.T) {
 	var (
 		req = [2]api.GenerateRequest{
 			{
-				Model:     "orca-mini",
+				Model:     "llama3.2:1b",
 				Prompt:    "why is the ocean blue?",
 				Stream:    &stream,
 				KeepAlive: &api.Duration{Duration: 10 * time.Second},
-				Options: map[string]interface{}{
+				Options: map[string]any{
 					"seed":        42,
 					"temperature": 0.0,
 				},
@@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) {
 				Prompt:    "what is the origin of the us thanksgiving holiday?",
 				Stream:    &stream,
 				KeepAlive: &api.Duration{Duration: 10 * time.Second},
-				Options: map[string]interface{}{
+				Options: map[string]any{
 					"seed":        42,
 					"temperature": 0.0,
 				},
@@ -67,7 +67,7 @@ func TestMultiModelConcurrency(t *testing.T) {
 	wg.Wait()
 }
 
-func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
+func TestIntegrationConcurrentPredict(t *testing.T) {
 	req, resp := GenerateRequests()
 	reqLimit := len(req)
 	iterLimit := 5
@@ -117,6 +117,9 @@ func TestMultiModelStress(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
+	if maxVram < 2*format.GibiByte {
+		t.Skip("VRAM less than 2G, skipping model stress tests")
+	}
 
 	type model struct {
 		name string
@@ -125,8 +128,8 @@ func TestMultiModelStress(t *testing.T) {
 
 	smallModels := []model{
 		{
-			name: "orca-mini",
-			size: 2992 * format.MebiByte,
+			name: "llama3.2:1b",
+			size: 2876 * format.MebiByte,
 		},
 		{
 			name: "phi",
@@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) {
 		Model:  "llama2",
 		Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 			"num_ctx":     128,
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
 		Model:  "llama2",
 		Prompt: "Write me a story with a ton of emojis?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 			"num_ctx":     128,
@@ -12,14 +12,63 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
-func TestIntegrationLlava(t *testing.T) {
+func TestVisionModels(t *testing.T) {
+	skipUnderMinVRAM(t, 6)
+	type testCase struct {
+		model string
+	}
+	testCases := []testCase{
+		{
+			model: "llava:7b",
+		},
+		{
+			model: "llama3.2-vision",
+		},
+		{
+			model: "gemma3",
+		},
+	}
+
+	for _, v := range testCases {
+		t.Run(v.model, func(t *testing.T) {
 			image, err := base64.StdEncoding.DecodeString(imageEncoding)
 			require.NoError(t, err)
 			req := api.GenerateRequest{
-				Model:  "llava:7b",
+				Model:  v.model,
 				Prompt: "what does the text in this image say?",
 				Stream: &stream,
-				Options: map[string]interface{}{
+				Options: map[string]any{
+					"seed":        42,
+					"temperature": 0.0,
+				},
+				Images: []api.ImageData{
+					image,
+				},
+			}
+			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+			defer cancel()
+			client, _, cleanup := InitServerConnection(ctx, t)
+
+			// Note: sometimes it returns "the ollamas" sometimes "the ollams"
+			resp := "the ollam"
+			defer cleanup()
+			require.NoError(t, PullIfMissing(ctx, client, req.Model))
+			// llava models on CPU can be quite slow to start
+			DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
+		})
+	}
+}
+
+func TestIntegrationSplitBatch(t *testing.T) {
+	image, err := base64.StdEncoding.DecodeString(imageEncoding)
+	require.NoError(t, err)
+	req := api.GenerateRequest{
+		Model: "gemma3:4b",
+		// Fill up a chunk of the batch so the image will partially spill over into the next one
+		System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
+		Prompt: "what does the text in this image say?",
+		Stream: &stream,
+		Options: map[string]any{
 			"seed":        42,
 			"temperature": 0.0,
 		},
@@ -39,33 +88,6 @@ func TestIntegrationLlava(t *testing.T) {
 	DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
 }
 
-func TestIntegrationMllama(t *testing.T) {
-	image, err := base64.StdEncoding.DecodeString(imageEncoding)
-	require.NoError(t, err)
-	req := api.GenerateRequest{
-		// TODO fix up once we publish the final image
-		Model:  "x/llama3.2-vision",
-		Prompt: "what does the text in this image say?",
-		Stream: &stream,
-		Options: map[string]interface{}{
-			"seed":        42,
-			"temperature": 0.0,
-		},
-		Images: []api.ImageData{
-			image,
-		},
-	}
-
-	resp := "the ollamas"
-	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
-	defer cancel()
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
-	require.NoError(t, PullIfMissing(ctx, client, req.Model))
-	// mllama models on CPU can be quite slow to start,
-	DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
-}
-
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
 AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
 AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
@@ -17,30 +17,30 @@ var (
 	stream = false
 	req    = [2]api.GenerateRequest{
 		{
-			Model:  "orca-mini",
+			Model:  smol,
 			Prompt: "why is the ocean blue?",
 			Stream: &stream,
-			Options: map[string]interface{}{
+			Options: map[string]any{
 				"seed":        42,
 				"temperature": 0.0,
 			},
 		}, {
-			Model:  "orca-mini",
+			Model:  smol,
 			Prompt: "what is the origin of the us thanksgiving holiday?",
 			Stream: &stream,
-			Options: map[string]interface{}{
+			Options: map[string]any{
 				"seed":        42,
 				"temperature": 0.0,
 			},
 		},
 	}
 	resp = [2][]string{
-		{"sunlight"},
+		{"sunlight", "scattering", "interact"},
 		{"england", "english", "massachusetts", "pilgrims"},
 	}
 )
 
-func TestIntegrationSimpleOrcaMini(t *testing.T) {
+func TestIntegrationSimple(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
 	defer cancel()
 	GenerateTestHelper(ctx, t, req[0], resp[0])
@@ -30,9 +30,9 @@ func TestMaxQueue(t *testing.T) {
 	t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
 
 	req := api.GenerateRequest{
-		Model:  "orca-mini",
+		Model:  smol,
 		Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"seed":        42,
 			"temperature": 0.0,
 		},
@@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
 	embedCtx := ctx
 
 	var genwg sync.WaitGroup
-	go func() {
 	genwg.Add(1)
+	go func() {
 		defer genwg.Done()
 		slog.Info("Starting generate request")
 		DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
@@ -61,7 +61,7 @@ func TestMaxQueue(t *testing.T) {
 	}()
 
 	// Give the generate a chance to get started before we start hammering on embed requests
-	time.Sleep(5 * time.Millisecond)
+	time.Sleep(10 * time.Millisecond)
 
 	threadCount += 10 // Add a few extra to ensure we push the queue past its limit
 	busyCount := 0
@@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
 	counterMu := sync.Mutex{}
 	var embedwg sync.WaitGroup
 	for i := 0; i < threadCount; i++ {
-		go func(i int) {
 		embedwg.Add(1)
+		go func(i int) {
 			defer embedwg.Done()
 			slog.Info("embed started", "id", i)
 			embedReq := api.EmbeddingRequest{
195 integration/model_arch_test.go Normal file
@@ -0,0 +1,195 @@
//go:build integration && models

package integration

import (
    "context"
    "encoding/json"
    "fmt"
    "io/ioutil"
    "log/slog"
    "os"
    "path/filepath"
    "strconv"
    "strings"
    "testing"
    "time"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/format"
)

var (
    started    = time.Now()
    chatModels = []string{
        "granite3-moe:latest",
        "granite-code:latest",
        "nemotron-mini:latest",
        "command-r:latest",
        "gemma2:latest",
        "gemma:latest",
        "internlm2:latest",
        "phi3.5:latest",
        "phi3:latest",
        // "phi:latest", // flaky, sometimes generates no response on first query
        "stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
        "falcon:latest",
        "falcon2:latest",
        "minicpm-v:latest",
        "mistral:latest",
        "orca-mini:latest",
        "llama2:latest",
        "llama3.1:latest",
        "llama3.2:latest",
        "llama3.2-vision:latest",
        "qwen2.5-coder:latest",
        "qwen:latest",
        "solar-pro:latest",
    }
)

func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
    deadline, hasDeadline := t.Deadline()
    if !hasDeadline {
        return 8 * time.Minute, 10 * time.Minute
    } else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
        t.Skip("too little time")
        return time.Duration(0), time.Duration(0)
    }
    return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
}

func TestModelsGenerate(t *testing.T) {
    softTimeout, hardTimeout := getTimeouts(t)
    slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
    ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
    defer cancel()
    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()

    // TODO use info API eventually
    var maxVram uint64
    var err error
    if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
        maxVram, err = strconv.ParseUint(s, 10, 64)
        if err != nil {
            t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
        }
    } else {
        slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
    }

    for _, model := range chatModels {
        t.Run(model, func(t *testing.T) {
            if time.Now().Sub(started) > softTimeout {
                t.Skip("skipping remaining tests to avoid excessive runtime")
            }
            if err := PullIfMissing(ctx, client, model); err != nil {
                t.Fatalf("pull failed %s", err)
            }
            if maxVram > 0 {
                resp, err := client.List(ctx)
                if err != nil {
                    t.Fatalf("list models failed %v", err)
                }
                for _, m := range resp.Models {
                    if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
                        t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
                    }
                }
            }
            // TODO - fiddle with context size
            req := api.GenerateRequest{
                Model:  model,
                Prompt: "why is the sky blue?",
                Options: map[string]interface{}{
                    "temperature": 0,
                    "seed":        123,
                },
            }
            anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
            DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
        })
    }
}

func TestModelsEmbed(t *testing.T) {
    softTimeout, hardTimeout := getTimeouts(t)
    ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
    defer cancel()
    client, _, cleanup := InitServerConnection(ctx, t)
    defer cleanup()

    // TODO use info API eventually
    var maxVram uint64
    var err error
    if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
        maxVram, err = strconv.ParseUint(s, 10, 64)
        if err != nil {
            t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
        }
    } else {
        slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
    }

    data, err := ioutil.ReadFile(filepath.Join("testdata", "embed.json"))
    if err != nil {
        t.Fatalf("failed to open test data file: %s", err)
    }
    testCase := map[string][]float64{}
    err = json.Unmarshal(data, &testCase)
    if err != nil {
        t.Fatalf("failed to load test data: %s", err)
    }
    for model, expected := range testCase {

        t.Run(model, func(t *testing.T) {
            if time.Now().Sub(started) > softTimeout {
                t.Skip("skipping remaining tests to avoid excessive runtime")
            }
            if err := PullIfMissing(ctx, client, model); err != nil {
                t.Fatalf("pull failed %s", err)
            }
            if maxVram > 0 {
                resp, err := client.List(ctx)
                if err != nil {
                    t.Fatalf("list models failed %v", err)
                }
                for _, m := range resp.Models {
                    if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
                        t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
                    }
                }
            }
            req := api.EmbeddingRequest{
                Model:  model,
                Prompt: "why is the sky blue?",
                Options: map[string]interface{}{
                    "temperature": 0,
                    "seed":        123,
                },
            }
            resp, err := client.Embeddings(ctx, &req)
            if err != nil {
                t.Fatalf("embeddings call failed %s", err)
            }
            if len(resp.Embedding) == 0 {
                t.Errorf("zero length embedding response")
            }
            if len(expected) != len(resp.Embedding) {
                expStr := make([]string, len(resp.Embedding))
                for i, v := range resp.Embedding {
                    expStr[i] = fmt.Sprintf("%0.6f", v)
                }
                // When adding new models, use this output to populate the testdata/embed.json
                fmt.Printf("expected\n%s\n", strings.Join(expStr, ", "))
                t.Fatalf("expected %d, got %d", len(expected), len(resp.Embedding))
            }
            sim := cosineSimilarity(resp.Embedding, expected)
            if sim < 0.99 {
                t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], resp.Embedding[0:5], sim)
            }
        })
    }

}
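TestModelsEmbed above calls cosineSimilarity, which is defined elsewhere in the integration helpers and not shown in this diff. For reference, a standard cosine-similarity helper over []float64 looks roughly like the sketch below; this is an assumption about its shape, not necessarily the repository's exact code:

package integration

import "math"

// cosineSimilarity returns the cosine of the angle between two equal-length vectors;
// a value near 1.0 means the embeddings point in nearly the same direction.
func cosineSimilarity(a, b []float64) float64 {
    var dot, normA, normB float64
    for i := range a {
        dot += a[i] * b[i]
        normA += a[i] * a[i]
        normB += b[i] * b[i]
    }
    if normA == 0 || normB == 0 {
        return 0
    }
    return dot / (math.Sqrt(normA) * math.Sqrt(normB))
}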
21 integration/testdata/embed.json vendored Normal file
File diff suppressed because one or more lines are too long
@@ -24,9 +24,14 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
+	"github.com/ollama/ollama/format"
 	"github.com/stretchr/testify/require"
 )
 
+const (
+	smol = "llama3.2:1b"
+)
+
 func Init() {
 	lifecycle.InitLogging()
 }
@@ -140,7 +145,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
 
 	showCtx, cancel := context.WithDeadlineCause(
 		ctx,
-		time.Now().Add(10*time.Second),
+		time.Now().Add(20*time.Second),
 		fmt.Errorf("show for existing model %s took too long", modelName),
 	)
 	defer cancel()
@@ -157,7 +162,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
 	}
 	slog.Info("model missing", "model", modelName)
 
-	stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
+	stallDuration := 60 * time.Second // This includes checksum verification, which can take a while on larger models, and slower systems
 	stallTimer := time.NewTimer(stallDuration)
 	fn := func(resp api.ProgressResponse) error {
 		// fmt.Print(".")
@@ -283,51 +288,51 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
 }
 
 // Generate a set of requests
-// By default each request uses orca-mini as the model
+// By default each request uses llama3.2 as the model
 func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 	return []api.GenerateRequest{
 		{
-			Model:     "orca-mini",
+			Model:     smol,
 			Prompt:    "why is the ocean blue?",
 			Stream:    &stream,
 			KeepAlive: &api.Duration{Duration: 10 * time.Second},
-			Options: map[string]interface{}{
+			Options: map[string]any{
 				"seed":        42,
 				"temperature": 0.0,
 			},
 		}, {
-			Model:     "orca-mini",
+			Model:     smol,
 			Prompt:    "why is the color of dirt brown?",
 			Stream:    &stream,
 			KeepAlive: &api.Duration{Duration: 10 * time.Second},
-			Options: map[string]interface{}{
+			Options: map[string]any{
 				"seed":        42,
 				"temperature": 0.0,
 			},
 		}, {
-			Model:     "orca-mini",
+			Model:     smol,
			Prompt:    "what is the origin of the us thanksgiving holiday?",
 			Stream:    &stream,
 			KeepAlive: &api.Duration{Duration: 10 * time.Second},
-			Options: map[string]interface{}{
+			Options: map[string]any{
 				"seed":        42,
 				"temperature": 0.0,
 			},
 		}, {
-			Model:     "orca-mini",
+			Model:     smol,
 			Prompt:    "what is the origin of independence day?",
 			Stream:    &stream,
 			KeepAlive: &api.Duration{Duration: 10 * time.Second},
-			Options: map[string]interface{}{
+			Options: map[string]any{
 				"seed":        42,
 				"temperature": 0.0,
 			},
 		}, {
-			Model:     "orca-mini",
+			Model:     smol,
 			Prompt:    "what is the composition of air?",
 			Stream:    &stream,
 			KeepAlive: &api.Duration{Duration: 10 * time.Second},
-			Options: map[string]interface{}{
+			Options: map[string]any{
 				"seed":        42,
 				"temperature": 0.0,
 			},
@@ -341,3 +346,15 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 		{"nitrogen", "oxygen", "carbon", "dioxide"},
 	}
 }
+
+func skipUnderMinVRAM(t *testing.T, gb uint64) {
+	// TODO use info API in the future
+	if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
+		maxVram, err := strconv.ParseUint(s, 10, 64)
+		require.NoError(t, err)
+		// Don't hammer on small VRAM cards...
+		if maxVram < gb*format.GibiByte {
+			t.Skip("skipping with small VRAM to avoid timeouts")
+		}
+	}
+}
77 kvcache/cache.go Normal file
@@ -0,0 +1,77 @@
package kvcache

import (
    "errors"

    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model/input"
)

var (
    ErrKvCacheFull  = errors.New("could not find a kv cache slot")
    ErrNotSupported = errors.New("model does not support operation")
)

type Cache interface {
    // ** used by model implementations **

    // SetLayer sets the active layer of the cache
    SetLayer(layer int)

    // Get returns the history of key and value tensors plus a mask
    //
    // The shape of the tensors is documented in the specific
    // cache implementation used.
    Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

    // Put stores a batch of key and value in the cache
    //
    // The shape of the tensors is documented in the specific
    // cache implementation used.
    Put(ctx ml.Context, key, value ml.Tensor)

    // SetConfig controls optimizations (mostly backend-specific) that may transform
    // the output of the cache to work better with specific kernels. If not called,
    // the backend settings will be used. This works well when calling Attention.
    //
    // The config can be overridden by models, especially if they require vanilla
    // output when implementing their own version of attention. To do this, pass
    // an empty ml.CacheConfig.
    //
    // Most models will not need to use this.
    SetConfig(ml.CacheConfig)

    // ** cache management **

    // Init sets up runtime parameters.
    // backend: Used to allocate cache data storage and execute management operations (such as defrag)
    // dtype: The data type for storing cache entries
    // maxSequences: The maximum number of sequences stored in the cache - across all batches
    // capacity: The number of cache entries to store, per sequence
    // maxBatch: The maximum number of tokens that can occur in a single batch
    Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)

    // Close closes the cache and frees resources associated with it
    Close()

    // StartForward is called before the start of the model's forward pass.
    // For each token in the coming batch, there must be a corresponding
    // entry in positions and seqs. reserve is to preallocate memory
    // without actually storing data in the cache.
    StartForward(ctx ml.Context, batch input.Batch, reserve bool) error

    // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
    CopyPrefix(srcSeq, dstSeq int, len int32)

    // CanResume returns true if the cache can continue with the next token at
    // the given position and sequence. Assumes that the caller has already
    // verified the contents of the cache.
    CanResume(seq int, pos int32) bool

    // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
    // endIndex to math.MaxInt32 to remove everything starting at beginIndex.
    //
    // If an error occurs, the entire context for the sequence should be
    // removed by calling Remove(seq, 0, math.MaxInt32)
    Remove(seq int, beginIndex, endIndex int32) error
}
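For orientation, the sketch below shows how a model implementation might drive this Cache interface during a forward pass: select the layer with SetLayer, store the batch's keys and values with Put, then read the full history plus attention mask back with Get. It is a hypothetical illustration based only on the method signatures above, not code from this change, and the attention math itself is elided:

package model

import (
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
)

// selfAttention is a hypothetical layer showing the intended call order.
func selfAttention(ctx ml.Context, cache kvcache.Cache, layer int, q, k, v ml.Tensor) ml.Tensor {
    cache.SetLayer(layer)                // select this layer's slice of the cache
    cache.Put(ctx, k, v)                 // append this batch's keys and values
    keys, values, mask := cache.Get(ctx) // history of keys/values plus the attention mask
    _, _, _ = keys, values, mask         // ... attention over keys/values with mask would go here
    return q
}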
739 kvcache/causal.go Normal file
@@ -0,0 +1,739 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||||
|
|
||||||
|
// Causal cache stores K and V tensors according to their position in the
|
||||||
|
// sequence. Returns the history and a mask for attending to past tokens
|
||||||
|
//
|
||||||
|
// The tensors are of shape embed dim, kv heads, batch size
|
||||||
|
// The mask is of shape history size, batch size
|
||||||
|
type Causal struct {
|
||||||
|
DType ml.DType
|
||||||
|
windowSize int32
|
||||||
|
chunkSize int32
|
||||||
|
|
||||||
|
opts CausalOptions
|
||||||
|
|
||||||
|
// config controls mostly backend-specific optimizations
|
||||||
|
config *ml.CacheConfig
|
||||||
|
|
||||||
|
// ** current forward pass **
|
||||||
|
|
||||||
|
// the active layer for Get and Put
|
||||||
|
curLayer int
|
||||||
|
|
||||||
|
// starting location for data storage for this batch
|
||||||
|
curLoc int
|
||||||
|
|
||||||
|
// size of the current batch
|
||||||
|
curBatchSize int
|
||||||
|
|
||||||
|
// mask of the cache as used by this batch
|
||||||
|
curMask ml.Tensor
|
||||||
|
|
||||||
|
// locations in the cache that are needed for this batch
|
||||||
|
curCellRange cellRange
|
||||||
|
|
||||||
|
// curSequences is the sequences corresponding to this pass's entries in the cache
|
||||||
|
curSequences []int
|
||||||
|
|
||||||
|
// curPositions is the positions corresponding to this pass's entries in the cache
|
||||||
|
curPositions []int32
|
||||||
|
|
||||||
|
// ** cache metadata **
|
||||||
|
|
||||||
|
// for each possible location in the cache, stores the position and set of sequences
|
||||||
|
// that reference the data there
|
||||||
|
cells []cacheCell
|
||||||
|
|
||||||
|
// maps from sequence to the range of locations where it is stored in the cache
|
||||||
|
cellRanges map[int]cellRange
|
||||||
|
|
||||||
|
// ** cache data storage **
|
||||||
|
|
||||||
|
shiftFn shiftFn
|
||||||
|
backend ml.Backend
|
||||||
|
ctxs map[int]ml.Context
|
||||||
|
keys, values map[int]ml.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheCell struct {
|
||||||
|
pos int32
|
||||||
|
sequences []int
|
||||||
|
}
|
||||||
|
|
||||||
|
type cellRange struct {
|
||||||
|
min int
|
||||||
|
max int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCausalCache(shift shiftFn) *Causal {
|
||||||
|
return &Causal{
|
||||||
|
windowSize: math.MaxInt32,
|
||||||
|
shiftFn: shift,
|
||||||
|
ctxs: make(map[int]ml.Context),
|
||||||
|
keys: make(map[int]ml.Tensor),
|
||||||
|
values: make(map[int]ml.Tensor),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||||
|
return &Causal{
|
||||||
|
windowSize: windowSize,
|
||||||
|
shiftFn: shift,
|
||||||
|
ctxs: make(map[int]ml.Context),
|
||||||
|
keys: make(map[int]ml.Tensor),
|
||||||
|
values: make(map[int]ml.Tensor),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||||
|
return &Causal{
|
||||||
|
windowSize: math.MaxInt32,
|
||||||
|
chunkSize: chunkSize,
|
||||||
|
shiftFn: shift,
|
||||||
|
ctxs: make(map[int]ml.Context),
|
||||||
|
keys: make(map[int]ml.Tensor),
|
||||||
|
values: make(map[int]ml.Tensor),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
|
if c.config == nil {
|
||||||
|
var config ml.CacheConfig
|
||||||
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
|
config = cc.CacheConfig()
|
||||||
|
}
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.CachePadding == 0 {
|
||||||
|
c.config.CachePadding = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.MaskBatchPadding == 0 {
|
||||||
|
c.config.MaskBatchPadding = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.MaskDType == ml.DTypeOther {
|
||||||
|
c.config.MaskDType = ml.DTypeF32
|
||||||
|
}
|
||||||
|
|
||||||
|
var cacheSize int
|
||||||
|
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
||||||
|
cacheSize = maxSequences * capacity
|
||||||
|
} else {
|
||||||
|
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
||||||
|
}
|
||||||
|
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||||
|
c.cells = make([]cacheCell, cacheSize)
|
||||||
|
|
||||||
|
c.DType = dtype
|
||||||
|
c.cellRanges = make(map[int]cellRange)
|
||||||
|
c.backend = backend
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||||
|
if c.config != nil {
|
||||||
|
panic("config cannot be changed after being previously set, either by the model or backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) Close() {
|
||||||
|
for _, ctx := range c.ctxs {
|
||||||
|
ctx.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||||
|
c.curBatchSize = len(batch.Positions)
|
||||||
|
c.curSequences = batch.Sequences
|
||||||
|
c.curPositions = batch.Positions
|
||||||
|
c.opts.Except = nil
|
||||||
|
|
||||||
|
if !reserve {
|
||||||
|
c.updateSlidingWindow()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
c.curLoc, err = c.findStartLoc()
|
||||||
|
if errors.Is(err, ErrKvCacheFull) {
|
||||||
|
c.defrag()
|
||||||
|
c.curLoc, err = c.findStartLoc()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.curCellRange = newRange()
|
||||||
|
for i, pos := range batch.Positions {
|
||||||
|
seq := batch.Sequences[i]
|
||||||
|
|
||||||
|
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||||
|
|
||||||
|
seqRange, ok := c.cellRanges[seq]
|
||||||
|
if !ok {
|
||||||
|
seqRange = newRange()
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.curLoc+i > seqRange.max {
|
||||||
|
seqRange.max = c.curLoc + i
|
||||||
|
}
|
||||||
|
if seqRange.max > c.curCellRange.max {
|
||||||
|
c.curCellRange.max = seqRange.max
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.curLoc+i < seqRange.min {
|
||||||
|
seqRange.min = c.curLoc + i
|
||||||
|
}
|
||||||
|
if seqRange.min < c.curCellRange.min {
|
||||||
|
c.curCellRange.min = seqRange.min
|
||||||
|
}
|
||||||
|
c.cellRanges[seq] = seqRange
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If we are reserving memory, don't update any of the cache metadata but set the size
|
||||||
|
// to the worst case.
|
||||||
|
c.curLoc = 0
|
||||||
|
c.curCellRange.min = 0
|
||||||
|
c.curCellRange.max = len(c.cells) - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
c.curMask, err = c.buildMask(ctx)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRange() cellRange {
|
||||||
|
return cellRange{
|
||||||
|
min: math.MaxInt,
|
||||||
|
max: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first contiguous block of at least curBatchSize
|
||||||
|
func (c *Causal) findStartLoc() (int, error) {
|
||||||
|
var start, count int
|
||||||
|
for i := range c.cells {
|
||||||
|
if len(c.cells[i].sequences) == 0 {
|
||||||
|
count++
|
||||||
|
if count >= c.curBatchSize {
|
||||||
|
return start, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
start = i + 1
|
||||||
|
count = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) updateSlidingWindow() {
|
||||||
|
if c.windowSize == math.MaxInt32 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a map of unique sequences to the lowest position in that sequence
|
||||||
|
lowestPos := make(map[int]int32)
|
||||||
|
for i := range c.curPositions {
|
||||||
|
seq := c.curSequences[i]
|
||||||
|
|
||||||
|
pos, ok := lowestPos[seq]
|
||||||
|
if !ok {
|
||||||
|
pos = c.curPositions[i]
|
||||||
|
} else if c.curPositions[i] < pos {
|
||||||
|
pos = c.curPositions[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
lowestPos[seq] = pos
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||||
|
for seq, pos := range lowestPos {
|
||||||
|
oldRange, ok := c.cellRanges[seq]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
newRange := newRange()
|
||||||
|
|
||||||
|
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||||
|
if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
if c.cells[i].pos < pos-c.windowSize {
|
||||||
|
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||||
|
} else {
|
||||||
|
newRange.min = min(newRange.min, i)
|
||||||
|
newRange.max = max(newRange.max, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[seq] = newRange
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
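// roundDown rounds length down to the nearest multiple of pad.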
func roundDown(length, pad int) int {
|
||||||
|
return (length / pad) * pad
|
||||||
|
}
|
||||||
|
|
||||||
|
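// roundUp rounds length up to the nearest multiple of pad.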
func roundUp(length, pad int) int {
|
||||||
|
return ((length + pad - 1) / pad) * pad
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds a history x batch mask indicating, for each token in the batch, whether each
// token in the history should be attended to. This is based on both the sequence and
// causality (the position of the history token must not be ahead of the token in the batch).
|
||||||
|
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||||
|
// Align and pad the two dimensions as required by the backend
|
||||||
|
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||||
|
|
||||||
|
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||||
|
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||||
|
|
||||||
|
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||||
|
mask := make([]float32, batchSize*length)
|
||||||
|
|
||||||
|
for i := range c.curBatchSize {
|
||||||
|
enabled := !slices.Contains(c.opts.Except, i)
|
||||||
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||||
|
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||||
|
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||||
|
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||||
|
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
||||||
|
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||||
|
// has already been masked out because the sequence doesn't match.
|
||||||
|
for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||||
|
mask[i] = float32(math.Inf(-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.MaskDType != ml.DTypeF32 {
|
||||||
|
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||||
|
ctx.Forward(maskTensor.Copy(ctx, out))
|
||||||
|
maskTensor = out
|
||||||
|
}
|
||||||
|
|
||||||
|
return maskTensor, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
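// moveCells queues copies of length cache cells from slot src to slot dst for each layer's key and value tensors.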
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
||||||
|
for i, key := range c.keys {
|
||||||
|
if key == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
kHeadDim := key.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
rowSize := key.Stride(2)
|
||||||
|
|
||||||
|
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
|
||||||
|
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
|
||||||
|
|
||||||
|
value := c.values[i]
|
||||||
|
var vSrcView, vDstView ml.Tensor
|
||||||
|
if c.config.PermutedV {
|
||||||
|
vHeadDim := value.Dim(1)
|
||||||
|
elemSize := value.Stride(0)
|
||||||
|
|
||||||
|
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||||
|
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||||
|
} else {
|
||||||
|
vHeadDim := value.Dim(0)
|
||||||
|
rowSize := value.Stride(2)
|
||||||
|
|
||||||
|
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
|
||||||
|
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(
|
||||||
|
kSrcView.Copy(ctx, kDstView),
|
||||||
|
vSrcView.Copy(ctx, vDstView),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Causal) defrag() {
|
||||||
|
slog.Debug("defragmenting kv cache")
|
||||||
|
|
||||||
|
// Defrag strategy:
|
||||||
|
// - Search for empty holes at the beginning of the cache,
|
||||||
|
// filling them with active data starting at the end
|
||||||
|
// - If there are contiguous elements that need to be moved,
|
||||||
|
// combine them into a single operation by holding new moves
|
||||||
|
// until we see that the next one is non-contiguous
|
||||||
|
// - Fill up the context with the maximum number of operations it
|
||||||
|
// can hold then compute that and continue with a new context
|
||||||
|
//
|
||||||
|
// We could try to optimize placement by grouping blocks from
|
||||||
|
// the same sequences together but most likely the next forward
|
||||||
|
// pass will disrupt this anyway, so the real-world benefit
|
||||||
|
// seems limited at this time.
|
||||||
|
|
||||||
|
ctx := c.backend.NewContext()
|
||||||
|
|
||||||
|
// For every move, 6 tensors are required per layer (2 views and a
|
||||||
|
// copy for each of k and v). We also need to refer to the original
|
||||||
|
// k and v cache tensors - once per layer, not per move.
|
||||||
|
layers := 0
|
||||||
|
for _, key := range c.keys {
|
||||||
|
if key == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
layers++
|
||||||
|
}
|
||||||
|
|
||||||
|
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
|
||||||
|
moves := 0
|
||||||
|
|
||||||
|
var pendingSrc, pendingDst, pendingLen int
|
||||||
|
src := len(c.cells) - 1
|
||||||
|
|
||||||
|
for dst := 0; dst < src; dst++ {
|
||||||
|
if len(c.cells[dst].sequences) == 0 {
|
||||||
|
for ; src > dst; src-- {
|
||||||
|
if len(c.cells[src].sequences) != 0 {
|
||||||
|
c.cells[dst] = c.cells[src]
|
||||||
|
c.cells[src] = cacheCell{}
|
||||||
|
|
||||||
|
if pendingLen > 0 {
|
||||||
|
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
|
||||||
|
pendingSrc = src
|
||||||
|
pendingLen++
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||||
|
moves++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingSrc = src
|
||||||
|
pendingDst = dst
|
||||||
|
pendingLen = 1
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if moves >= maxMoves {
|
||||||
|
ctx.Compute()
|
||||||
|
ctx.Close()
|
||||||
|
ctx = c.backend.NewContext()
|
||||||
|
|
||||||
|
moves = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pendingLen > 0 {
|
||||||
|
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||||
|
moves++
|
||||||
|
}
|
||||||
|
|
||||||
|
if moves > 0 {
|
||||||
|
ctx.Compute()
|
||||||
|
}
|
||||||
|
ctx.Close()
|
||||||
|
|
||||||
|
// Reset range metadata
|
||||||
|
for seq := range c.cellRanges {
|
||||||
|
seqRange := newRange()
|
||||||
|
|
||||||
|
for i, cell := range c.cells {
|
||||||
|
if slices.Contains(cell.sequences, seq) {
|
||||||
|
if i < seqRange.min {
|
||||||
|
seqRange.min = i
|
||||||
|
}
|
||||||
|
if i > seqRange.max {
|
||||||
|
seqRange.max = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[seq] = seqRange
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
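// SetLayer sets the layer that subsequent calls to Get and Put operate on.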
func (c *Causal) SetLayer(layer int) {
|
||||||
|
c.curLayer = layer
|
||||||
|
}
|
||||||
|
|
||||||
|
type CausalOptions struct {
|
||||||
|
// Except lists the batch indices for which causal masking is disabled
|
||||||
|
Except []int
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCausal disables causal mask generation for a particular range of indices in
|
||||||
|
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
|
||||||
|
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||||
|
if !slices.Equal(c.opts.Except, opts.Except) {
|
||||||
|
c.opts = opts
|
||||||
|
if ctx != nil {
|
||||||
|
var err error
|
||||||
|
c.curMask, err = c.buildMask(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// This error should never occur because we have previously built a mask with the same shape
|
||||||
|
panic(fmt.Errorf("SetCausal: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
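// Get returns views of the cached keys and values for the current layer, restricted to the active cell range, along with the current attention mask.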
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
|
key := c.keys[c.curLayer]
|
||||||
|
value := c.values[c.curLayer]
|
||||||
|
|
||||||
|
kHeadDim := key.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
rowSize := key.Stride(2)
|
||||||
|
cachedSize := c.curMask.Dim(0)
|
||||||
|
|
||||||
|
key = key.View(ctx, rowSize*c.curCellRange.min,
|
||||||
|
kHeadDim, key.Stride(1),
|
||||||
|
numKVHeads, key.Stride(2),
|
||||||
|
cachedSize,
|
||||||
|
)
|
||||||
|
|
||||||
|
if c.config.PermutedV {
|
||||||
|
vHeadDim := value.Dim(1)
|
||||||
|
elemSize := value.Stride(0)
|
||||||
|
|
||||||
|
value = value.View(ctx, elemSize*c.curCellRange.min,
|
||||||
|
cachedSize, value.Stride(1),
|
||||||
|
vHeadDim, value.Stride(2),
|
||||||
|
numKVHeads,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
vHeadDim := value.Dim(0)
|
||||||
|
rowSize := value.Stride(2)
|
||||||
|
|
||||||
|
value = value.View(ctx, rowSize*c.curCellRange.min,
|
||||||
|
vHeadDim, value.Stride(1),
|
||||||
|
numKVHeads, value.Stride(2),
|
||||||
|
cachedSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, value, c.curMask
|
||||||
|
}
|
||||||
|
|
||||||
|
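// Put writes the batch's key and value tensors into the cache at the location chosen by StartForward, lazily allocating the per-layer cache tensors on first use.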
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
|
kHeadDim := key.Dim(0)
|
||||||
|
vHeadDim := value.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
batchSize := key.Dim(2)
|
||||||
|
|
||||||
|
if c.curBatchSize != batchSize {
|
||||||
|
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||||
|
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.keys[c.curLayer]; !ok {
|
||||||
|
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.values[c.curLayer]; !ok {
|
||||||
|
if c.config.PermutedV {
|
||||||
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
|
||||||
|
} else {
|
||||||
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rowSize := c.keys[c.curLayer].Stride(2)
|
||||||
|
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
|
||||||
|
|
||||||
|
if c.config.PermutedV {
|
||||||
|
elemSize := c.values[c.curLayer].Stride(0)
|
||||||
|
|
||||||
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||||
|
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
|
||||||
|
} else {
|
||||||
|
rowSize := c.values[c.curLayer].Stride(2)
|
||||||
|
|
||||||
|
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
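// CopyPrefix makes the first len positions of srcSeq visible to dstSeq, discarding any existing contents of dstSeq.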
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
|
seqRange := newRange()
|
||||||
|
|
||||||
|
for i := range c.cells {
|
||||||
|
// Remove the contents of dstSeq so that we only have the copied prefix; metadata will be reset at the end
|
||||||
|
if slices.Contains(c.cells[i].sequences, dstSeq) {
|
||||||
|
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
|
||||||
|
c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
|
||||||
|
if i < seqRange.min {
|
||||||
|
seqRange.min = i
|
||||||
|
}
|
||||||
|
if i > seqRange.max {
|
||||||
|
seqRange.max = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[dstSeq] = seqRange
|
||||||
|
}
|
||||||
|
|
||||||
|
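// CanResume reports whether seq can be resumed from pos, i.e. whether the sliding window for pos is still covered by what is stored in the cache.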
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||||
|
if c.windowSize == math.MaxInt32 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
seqRange, ok := c.cellRanges[seq]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// for sliding window, check that the window of the new sequence is contained in
|
||||||
|
// the window of what we are storing
|
||||||
|
var last int32 = -1
|
||||||
|
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||||
|
if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
last = max(last, c.cells[i].pos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if last == -1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
lastWindowStart := max(0, last-c.windowSize)
|
||||||
|
posWindowStart := max(0, pos-c.windowSize)
|
||||||
|
|
||||||
|
return posWindowStart >= lastWindowStart
|
||||||
|
}
|
||||||
|
|
||||||
|
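// shift adjusts the cached entries of seq whose position is at or after beginIndex by offset, applying the model-provided shiftFn to the affected keys.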
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||||
|
if c.shiftFn == nil {
|
||||||
|
return ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := c.backend.NewContext()
|
||||||
|
defer ctx.Close()
|
||||||
|
|
||||||
|
seqRange := c.cellRanges[seq]
|
||||||
|
size := seqRange.max - seqRange.min + 1
|
||||||
|
|
||||||
|
offsets := make([]int32, size)
|
||||||
|
for i := range offsets {
|
||||||
|
cell := c.cells[seqRange.min+i]
|
||||||
|
|
||||||
|
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||||
|
offsets[i] = offset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, key := range c.keys {
|
||||||
|
if key == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
kHeadDim := key.Dim(0)
|
||||||
|
numKVHeads := key.Dim(1)
|
||||||
|
rowSize := key.Stride(2)
|
||||||
|
|
||||||
|
key = key.View(ctx, rowSize*seqRange.min,
|
||||||
|
kHeadDim, key.Stride(1),
|
||||||
|
numKVHeads, key.Stride(2),
|
||||||
|
size,
|
||||||
|
)
|
||||||
|
|
||||||
|
roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(roped.Copy(ctx, key))
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Compute()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
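// Remove evicts positions [beginIndex, endIndex) of seq from the cache and, unless the entire tail is removed, shifts the remaining later entries back to close the gap.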
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
// TODO(jessegross): We should check to see if removing the middle of the sequence will
|
||||||
|
// cause the sliding window to encompass tokens that we no longer have. If so, then we
|
||||||
|
// should return an error, which will trigger the runner to evaluate the full history and
|
||||||
|
// rebuild the window. However, if we have multimodal inputs in our history, this reuse
|
||||||
|
// results in use after free, so we don't do it for now.
|
||||||
|
|
||||||
|
var offset int32
|
||||||
|
if endIndex != math.MaxInt32 {
|
||||||
|
offset = beginIndex - endIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
seqRange := newRange()
|
||||||
|
|
||||||
|
for i := range c.cells {
|
||||||
|
if slices.Contains(c.cells[i].sequences, seq) {
|
||||||
|
if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
|
||||||
|
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||||
|
} else {
|
||||||
|
if c.cells[i].pos >= endIndex {
|
||||||
|
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||||
|
return errors.New("shifting cells shared by multiple sequences not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cells[i].pos += offset
|
||||||
|
}
|
||||||
|
if i < seqRange.min {
|
||||||
|
seqRange.min = i
|
||||||
|
}
|
||||||
|
if i > seqRange.max {
|
||||||
|
seqRange.max = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if seqRange == newRange() {
|
||||||
|
delete(c.cellRanges, seq)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cellRanges[seq] = seqRange
|
||||||
|
|
||||||
|
if endIndex != math.MaxInt32 {
|
||||||
|
err := c.shift(seq, endIndex+offset, offset)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
598
kvcache/causal_test.go
Normal file
@@ -0,0 +1,598 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
in []float32
|
||||||
|
inShape []int
|
||||||
|
seqs []int
|
||||||
|
pos []int32
|
||||||
|
expected []float32
|
||||||
|
expectedShape []int
|
||||||
|
expectedMask []float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(nil)
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||||
|
inShape: []int{2, 3, 4},
|
||||||
|
seqs: []int{0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3},
|
||||||
|
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||||
|
expectedShape: []int{2, 3, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{115, 215, 125, 225, 135, 235},
|
||||||
|
inShape: []int{2, 3, 1},
|
||||||
|
seqs: []int{0},
|
||||||
|
pos: []int32{4},
|
||||||
|
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
|
||||||
|
expectedShape: []int{2, 3, 5},
|
||||||
|
expectedMask: []float32{0, 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSWA(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewSWACache(1, nil)
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 0},
|
||||||
|
pos: []int32{4, 5},
|
||||||
|
expected: []float32{5, 6, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunkedAttention(t *testing.T) {
|
||||||
|
cache := NewChunkedAttentionCache(2, nil)
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
var b testBackend
|
||||||
|
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
x := float32(math.Inf(-1))
|
||||||
|
|
||||||
|
testCache(
|
||||||
|
t, &b, cache,
|
||||||
|
[]testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{
|
||||||
|
0, x, x, x,
|
||||||
|
0, 0, x, x,
|
||||||
|
x, x, 0, x,
|
||||||
|
x, x, 0, 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{5, 6, 7},
|
||||||
|
inShape: []int{1, 1, 3},
|
||||||
|
seqs: []int{0, 0, 0},
|
||||||
|
pos: []int32{4, 5, 6},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6, 7},
|
||||||
|
expectedShape: []int{1, 1, 7},
|
||||||
|
expectedMask: []float32{
|
||||||
|
x, x, x, x, 0, x, x,
|
||||||
|
x, x, x, x, 0, 0, x,
|
||||||
|
x, x, x, x, x, x, 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ThirdBatch",
|
||||||
|
in: []float32{8, 9},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 0},
|
||||||
|
pos: []int32{7, 8},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||||
|
expectedShape: []int{1, 1, 9},
|
||||||
|
expectedMask: []float32{
|
||||||
|
x, x, x, x, x, x, 0, 0, x,
|
||||||
|
x, x, x, x, x, x, x, x, 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSequences(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(nil)
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 1, 1},
|
||||||
|
pos: []int32{0, 1, 0, 1},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SecondBatch",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 1},
|
||||||
|
pos: []int32{2, 2},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||||
|
expectedShape: []int{1, 1, 6},
|
||||||
|
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemove(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
return key.Add(ctx, shift), nil
|
||||||
|
})
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 1, 1},
|
||||||
|
pos: []int32{0, 1, 0, 1},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
err := cache.Remove(0, 1, math.MaxInt32)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests = []testCase{
|
||||||
|
{
|
||||||
|
name: "RemoveEnd",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 1},
|
||||||
|
pos: []int32{1, 2},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||||
|
expectedShape: []int{1, 1, 6},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
err = cache.Remove(0, 0, 1)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests = []testCase{
|
||||||
|
{
|
||||||
|
name: "RemoveMiddle",
|
||||||
|
in: []float32{7, 8},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{0, 0},
|
||||||
|
pos: []int32{1, 2},
|
||||||
|
expected: []float32{7, 8, 3, 4, 4},
|
||||||
|
expectedShape: []int{1, 1, 5},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefrag(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
return key.Add(ctx, shift), nil
|
||||||
|
})
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
|
inShape: []int{1, 1, 16},
|
||||||
|
seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||||
|
expectedShape: []int{1, 1, 16},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
err := cache.Remove(0, 2, 4)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cache.Remove(0, 13, math.MaxInt32)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests = []testCase{
|
||||||
|
{
|
||||||
|
name: "Defrag",
|
||||||
|
in: []float32{17, 18, 19},
|
||||||
|
inShape: []int{1, 1, 3},
|
||||||
|
seqs: []int{0, 0, 0},
|
||||||
|
pos: []int32{16, 17, 18},
|
||||||
|
expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
|
||||||
|
expectedShape: []int{1, 1, 16},
|
||||||
|
expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopy(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
name: "FirstBatch",
|
||||||
|
in: []float32{1, 2, 3, 4},
|
||||||
|
inShape: []int{1, 1, 4},
|
||||||
|
seqs: []int{0, 0, 0, 0},
|
||||||
|
pos: []int32{0, 1, 2, 3},
|
||||||
|
expected: []float32{1, 2, 3, 4},
|
||||||
|
expectedShape: []int{1, 1, 4},
|
||||||
|
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
|
||||||
|
cache.CopyPrefix(0, 1, 2)
|
||||||
|
|
||||||
|
tests = []testCase{
|
||||||
|
{
|
||||||
|
name: "Copy",
|
||||||
|
in: []float32{5, 6},
|
||||||
|
inShape: []int{1, 1, 2},
|
||||||
|
seqs: []int{1, 1},
|
||||||
|
pos: []int32{3, 4},
|
||||||
|
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||||
|
expectedShape: []int{1, 1, 6},
|
||||||
|
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCache(t, backend, cache, tests)
|
||||||
|
}
|
||||||
|
|
||||||
|
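// testCache runs each test case through a forward pass, storing the input as both keys and values, and compares the retrieved values, shape, and mask against the expected results.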
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
context := backend.NewContext()
|
||||||
|
defer context.Close()
|
||||||
|
|
||||||
|
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cache.SetLayer(0)
|
||||||
|
tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
|
||||||
|
cache.Put(context, tensor, tensor)
|
||||||
|
|
||||||
|
out, _, mask := cache.Get(context)
|
||||||
|
|
||||||
|
context.Forward(out, mask).Compute(out, mask)
|
||||||
|
|
||||||
|
if !slices.Equal(out.Floats(), test.expected) {
|
||||||
|
t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Equal(out.Shape(), test.expectedShape) {
|
||||||
|
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||||
|
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanResume(t *testing.T) {
|
||||||
|
backend := &testBackend{}
|
||||||
|
windowSize := int32(4)
|
||||||
|
cache := NewSWACache(windowSize, nil)
|
||||||
|
defer cache.Close()
|
||||||
|
|
||||||
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
|
|
||||||
|
context := backend.NewContext()
|
||||||
|
defer context.Close()
|
||||||
|
|
||||||
|
err := cache.StartForward(context, input.Batch{
|
||||||
|
Positions: []int32{0, 1, 2, 3},
|
||||||
|
Sequences: []int{0, 0, 0, 0},
|
||||||
|
}, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartForward failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cache.SetLayer(0)
|
||||||
|
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||||
|
cache.Put(context, tensor, tensor)
|
||||||
|
|
||||||
|
// with window size 4, nothing has slid out of the window yet
|
||||||
|
if !cache.CanResume(0, 0) {
|
||||||
|
t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
||||||
|
}
|
||||||
|
if !cache.CanResume(0, 1) {
|
||||||
|
t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
||||||
|
}
|
||||||
|
if !cache.CanResume(0, 2) {
|
||||||
|
t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
||||||
|
}
|
||||||
|
if !cache.CanResume(0, 3) {
|
||||||
|
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// shift window by adding position 4
|
||||||
|
err = cache.StartForward(context, input.Batch{
|
||||||
|
Positions: []int32{4, 5},
|
||||||
|
Sequences: []int{0, 0},
|
||||||
|
}, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartForward failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cache.SetLayer(0)
|
||||||
|
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||||
|
cache.Put(context, tensor, tensor)
|
||||||
|
|
||||||
|
// only the latest position has overlapping windows
|
||||||
|
if cache.CanResume(0, 0) {
|
||||||
|
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||||
|
}
|
||||||
|
if cache.CanResume(0, 1) {
|
||||||
|
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||||
|
}
|
||||||
|
if cache.CanResume(0, 2) {
|
||||||
|
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||||
|
}
|
||||||
|
if cache.CanResume(0, 3) {
|
||||||
|
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||||
|
}
|
||||||
|
if cache.CanResume(0, 4) {
|
||||||
|
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||||
|
}
|
||||||
|
if !cache.CanResume(0, 5) {
|
||||||
|
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testBackend struct {
|
||||||
|
ml.Backend
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testBackend) NewContext() ml.Context {
|
||||||
|
return &testContext{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testBackend) NewContextSize(int) ml.Context {
|
||||||
|
return &testContext{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testContext struct {
|
||||||
|
ml.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
|
total := 0
|
||||||
|
|
||||||
|
if len(shape) > 0 {
|
||||||
|
total = 1
|
||||||
|
for _, s := range shape {
|
||||||
|
total *= s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
|
return c.Empty(dtype, shape...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||||
|
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
||||||
|
|
||||||
|
copy(t.data, s)
|
||||||
|
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||||
|
f := make([]float32, len(s))
|
||||||
|
for i := range f {
|
||||||
|
f[i] = float32(s[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
out, _ := c.FromFloatSlice(f, shape...)
|
||||||
|
out.(*testTensor).dtype = ml.DTypeI32
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||||
|
s := make([]float32, 0, int((stop-start)/step))
|
||||||
|
for i := start; i < stop; i += step {
|
||||||
|
s = append(s, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, _ := c.FromFloatSlice(s, len(s))
|
||||||
|
out.(*testTensor).dtype = dtype
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testContext) Input() ml.Context { return c }
|
||||||
|
func (c *testContext) Layer(int) ml.Context { return c }
|
||||||
|
|
||||||
|
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||||
|
|
||||||
|
func (c *testContext) Compute(...ml.Tensor) {}
|
||||||
|
|
||||||
|
func (c *testContext) Reserve() error { return nil }
|
||||||
|
|
||||||
|
func (c *testContext) MaxGraphNodes() int {
|
||||||
|
return 10
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *testContext) Close() {}
|
||||||
|
|
||||||
|
type testTensor struct {
|
||||||
|
ml.Tensor
|
||||||
|
|
||||||
|
dtype ml.DType
|
||||||
|
elementSize int
|
||||||
|
data []float32
|
||||||
|
shape []int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) Dim(n int) int {
|
||||||
|
return t.shape[n]
|
||||||
|
}
|
||||||
|
|
||||||
|
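// Stride returns the byte stride of dimension n, assuming densely packed data.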
func (t *testTensor) Stride(n int) int {
|
||||||
|
stride := t.elementSize
|
||||||
|
for i := range n {
|
||||||
|
stride *= t.shape[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return stride
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) Shape() []int {
|
||||||
|
return t.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) DType() ml.DType {
|
||||||
|
return t.dtype
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) Floats() []float32 {
|
||||||
|
out := make([]float32, len(t.data))
|
||||||
|
copy(out, t.data)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
|
||||||
|
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||||
|
for i := range out.data {
|
||||||
|
out.data[i] = -t.data[i]
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
|
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||||
|
|
||||||
|
for i := range out.data {
|
||||||
|
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
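// View returns a tensor sharing the underlying data starting at the given byte offset; only the 1- and 5-argument shape forms used by the cache code are supported.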
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||||
|
offset /= t.elementSize
|
||||||
|
|
||||||
|
var s []int
|
||||||
|
|
||||||
|
switch len(shape) {
|
||||||
|
case 1:
|
||||||
|
s = []int{shape[0]}
|
||||||
|
case 5:
|
||||||
|
s = []int{shape[0], shape[2], shape[4]}
|
||||||
|
default:
|
||||||
|
panic("unsupported number of dimensions")
|
||||||
|
}
|
||||||
|
|
||||||
|
context := &testContext{}
|
||||||
|
|
||||||
|
view := context.Empty(t.dtype, s...).(*testTensor)
|
||||||
|
view.data = t.data[offset : offset+len(view.data)]
|
||||||
|
|
||||||
|
return view
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
|
copy(t2.(*testTensor).data, t.data)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
156
kvcache/encoder.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Encoder cache stores K and V tensors that are position independent
|
||||||
|
//
|
||||||
|
// The tensors can be of any shape and will be returned as they were stored
|
||||||
|
// The mask is currently always nil
|
||||||
|
//
|
||||||
|
// Not currently safe for multiple sequences
|
||||||
|
type EncoderCache struct {
|
||||||
|
// config controls mostly backend-specific optimizations
|
||||||
|
config *ml.CacheConfig
|
||||||
|
|
||||||
|
// ** current forward pass **
|
||||||
|
|
||||||
|
// the active layer for Get and Put
|
||||||
|
curLayer int
|
||||||
|
|
||||||
|
// if something is stored during this pass, this
|
||||||
|
// will be the position (but there is no guarantee
|
||||||
|
// anything will be stored)
|
||||||
|
curPos int32
|
||||||
|
|
||||||
|
// curReserve indicates that this forward pass is only for
|
||||||
|
// memory reservation and we should not update our metadata
|
||||||
|
// based on it.
|
||||||
|
curReserve bool
|
||||||
|
|
||||||
|
// ** cache metadata **
|
||||||
|
|
||||||
|
// was something stored in the cache?
|
||||||
|
encoderCached bool
|
||||||
|
|
||||||
|
// position of the cached data
|
||||||
|
encoderPos int32
|
||||||
|
|
||||||
|
// ** cache data storage **
|
||||||
|
backend ml.Backend
|
||||||
|
ctxs map[int]ml.Context
|
||||||
|
keys, values map[int]ml.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEncoderCache() *EncoderCache {
|
||||||
|
return &EncoderCache{
|
||||||
|
ctxs: make(map[int]ml.Context),
|
||||||
|
keys: make(map[int]ml.Tensor),
|
||||||
|
values: make(map[int]ml.Tensor),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
|
if c.config == nil {
|
||||||
|
var config ml.CacheConfig
|
||||||
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||||
|
config = cc.CacheConfig()
|
||||||
|
}
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxSequences > 1 {
|
||||||
|
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||||
|
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.backend = backend
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
||||||
|
if c.config != nil {
|
||||||
|
panic("config cannot be changed after being previously set, either by the model or backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.config = &config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) Close() {
|
||||||
|
for _, ctx := range c.ctxs {
|
||||||
|
ctx.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||||
|
// We work with the most recent image
|
||||||
|
if len(batch.Multimodal) > 0 {
|
||||||
|
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||||
|
}
|
||||||
|
|
||||||
|
c.curReserve = reserve
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) SetLayer(layer int) {
|
||||||
|
c.curLayer = layer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) EncoderCached() bool {
|
||||||
|
return c.encoderCached
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
|
return c.keys[c.curLayer], c.values[c.curLayer], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
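// Put stores the key and value tensors for the current layer, allocating storage on first use and recording the position unless this pass only reserves memory.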
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
|
if !c.curReserve {
|
||||||
|
c.encoderPos = c.curPos
|
||||||
|
c.encoderCached = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.PermutedV {
|
||||||
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.ctxs[c.curLayer]; !ok {
|
||||||
|
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.keys[c.curLayer]; !ok {
|
||||||
|
c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.values[c.curLayer]; !ok {
|
||||||
|
c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(
|
||||||
|
key.Copy(ctx, c.keys[c.curLayer]),
|
||||||
|
value.Copy(ctx, c.values[c.curLayer]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
|
panic("encoder cache does not support multiple sequences")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) CanResume(seq int, pos int32) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||||
|
c.encoderCached = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
110
kvcache/wrapper.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package kvcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Wrapper cache is a container for multiple types of caches,
|
||||||
|
// such as for the encoding and decoding portions of a model.
|
||||||
|
type WrapperCache struct {
|
||||||
|
// caches we are wrapping
|
||||||
|
caches []Cache
|
||||||
|
|
||||||
|
// cache to be used for this layer
|
||||||
|
curType int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWrapperCache(caches ...Cache) *WrapperCache {
|
||||||
|
return &WrapperCache{
|
||||||
|
caches: caches,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||||
|
for _, cache := range c.caches {
|
||||||
|
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
|
||||||
|
for _, cache := range c.caches {
|
||||||
|
cache.SetConfig(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) Close() {
|
||||||
|
for _, cache := range c.caches {
|
||||||
|
cache.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
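// StartForward begins a forward pass on every wrapped cache; if one of them fails, positions already added to earlier caches are removed before returning the error.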
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||||
|
for i, cache := range c.caches {
|
||||||
|
err := cache.StartForward(ctx, batch, reserve)
|
||||||
|
if err != nil {
|
||||||
|
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||||
|
for j := i - 1; j >= 0; j-- {
|
||||||
|
for k := range batch.Positions {
|
||||||
|
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.curType = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) SetLayer(layer int) {
|
||||||
|
for _, cache := range c.caches {
|
||||||
|
cache.SetLayer(layer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) SetLayerType(layerType int) {
|
||||||
|
c.curType = layerType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) UnderlyingCache() Cache {
|
||||||
|
return c.caches[c.curType]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
|
return c.caches[c.curType].Get(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
|
c.caches[c.curType].Put(ctx, key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
|
for _, cache := range c.caches {
|
||||||
|
cache.CopyPrefix(srcSeq, dstSeq, len)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) CanResume(seq int, pos int32) bool {
|
||||||
|
for _, cache := range c.caches {
|
||||||
|
if !cache.CanResume(seq, pos) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
// If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
|
||||||
|
for _, cache := range c.caches {
|
||||||
|
err := cache.Remove(seq, beginIndex, endIndex)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
137
llama/README.md
@@ -1,157 +1,52 @@
|
|||||||
# `llama`
|
# `llama`
|
||||||
|
|
||||||
This package integrates the [llama.cpp](https://github.com/ggerganov/llama.cpp) library as a Go package and makes it easy to build it with tags for different CPU and GPU processors.
|
This package provides Go bindings to [llama.cpp](https://github.com/ggerganov/llama.cpp).
|
||||||
|
|
||||||
Supported:
|
|
||||||
|
|
||||||
- [x] CPU
|
|
||||||
- [x] avx, avx2
|
|
||||||
- [x] macOS Metal
|
|
||||||
- [x] Windows CUDA
|
|
||||||
- [x] Windows ROCm
|
|
||||||
- [x] Linux CUDA
|
|
||||||
- [x] Linux ROCm
|
|
||||||
- [x] Llava
|
|
||||||
|
|
||||||
Extra build steps are required for CUDA and ROCm on Windows since `nvcc` and `hipcc` both require using msvc as the host compiler. For these, shared libraries are created:
|
|
||||||
|
|
||||||
- `ggml_cuda.dll` on Windows or `ggml_cuda.so` on Linux
|
|
||||||
- `ggml_hipblas.dll` on Windows or `ggml_hipblas.so` on Linux
|
|
||||||
|
|
||||||
> Note: it's important that memory is allocated and freed by the same compiler (e.g. entirely by code compiled with msvc or mingw). Issues from this should be rare, but there are some places where pointers are returned by the CUDA or HIP runtimes and freed elsewhere, causing a crash. In a future change the same runtime should be used in both cases to avoid crashes.
|
|
||||||
|
|
||||||
## Building
|
|
||||||
|
|
||||||
```
|
|
||||||
go build .
|
|
||||||
```
|
|
||||||
|
|
||||||
### AVX
|
|
||||||
|
|
||||||
```shell
|
|
||||||
go build -tags avx .
|
|
||||||
```
|
|
||||||
|
|
||||||
### AVX2
|
|
||||||
|
|
||||||
```shell
|
|
||||||
# go doesn't recognize `-mfma` as a valid compiler flag
|
|
||||||
# see https://github.com/golang/go/issues/17895
|
|
||||||
go env -w "CGO_CFLAGS_ALLOW=-mfma|-mf16c"
|
|
||||||
go env -w "CGO_CXXFLAGS_ALLOW=-mfma|-mf16c"
|
|
||||||
go build -tags=avx,avx2 .
|
|
||||||
```
|
|
||||||
|
|
||||||
## Linux
|
|
||||||
|
|
||||||
### CUDA
|
|
||||||
|
|
||||||
Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive):
|
|
||||||
|
|
||||||
```shell
|
|
||||||
make ggml_cuda.so
|
|
||||||
go build -tags avx,cuda .
|
|
||||||
```
|
|
||||||
|
|
||||||
### ROCm
|
|
||||||
|
|
||||||
Install [ROCm](https://rocm.docs.amd.com/en/latest/).
|
|
||||||
|
|
||||||
```shell
|
|
||||||
make ggml_hipblas.so
|
|
||||||
go build -tags avx,rocm .
|
|
||||||
```
|
|
||||||
|
|
||||||
## Windows
|
|
||||||
|
|
||||||
Download [w64devkit](https://github.com/skeeto/w64devkit/releases/latest) for a simple MinGW development environment.
|
|
||||||
|
|
||||||
### CUDA
|
|
||||||
|
|
||||||
Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive) then build the cuda code:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
make ggml_cuda.dll
|
|
||||||
go build -tags avx,cuda .
|
|
||||||
```
|
|
||||||
|
|
||||||
### ROCm
|
|
||||||
|
|
||||||
Install [ROCm](https://rocm.docs.amd.com/en/latest/).
|
|
||||||
|
|
||||||
```shell
|
|
||||||
make ggml_hipblas.dll
|
|
||||||
go build -tags avx,rocm .
|
|
||||||
```
|
|
||||||
|
|
||||||
## Building runners
|
|
||||||
|
|
||||||
```shell
|
|
||||||
# build all runners for this platform
|
|
||||||
make -j
|
|
||||||
```
|
|
||||||
|
|
||||||
## Vendoring
|
## Vendoring
|
||||||
|
|
||||||
Ollama currently vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/ggml) through a vendoring model. While we generally strive to contribute changes back upstream to avoid drift, we cary a small set of patches which are applied to the tracking commit. A set of make targets are available to aid developers in updating to a newer tracking commit, or to work on changes.
|
Ollama vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/llama.cpp/tree/master/ggml/src). While we generally strive to contribute changes back upstream to avoid drift, we carry a small set of patches which are applied to the tracking commit.
|
||||||
|
|
||||||
If you update the vendoring code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory.
|
If you update the vendoring code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory.
|
||||||
|
|
||||||
```
|
```shell
|
||||||
make apply-patches
|
make -f Makefile.sync apply-patches
|
||||||
```
|
```
|
||||||
|
|
||||||
### Updating Base Commit
|
### Updating Base Commit
|
||||||
|
|
||||||
**Pin to new base commit**
|
**Pin to new base commit**
|
||||||
|
|
||||||
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`
|
To change the base commit, update `FETCH_HEAD` in Makefile.sync.
|
||||||
|
|
||||||
#### Applying patches
|
|
||||||
|
|
||||||
When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.
|
When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.
|
||||||
|
|
||||||
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
|
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
|
||||||
|
|
||||||
```
|
```shell
|
||||||
make apply-patches
|
make -f Makefile.sync apply-patches
|
||||||
```
|
```
|
||||||
|
|
||||||
If you see an error message about a conflict, go into the `./vendor/` directory, and perform merge resolution using your preferred tool to the patch commit which failed. Save the file(s) and continue the patch series with `git am --continue` . If any additional patches fail, follow the same pattern until the full patch series is applied. Once finished, run a final `create-patches` and `sync` target to ensure everything is updated.
|
If there are conflicts, you will see an error message. Resolve the conflicts in `./vendor/`, and continue the patch series with `git am --continue` and rerun `make -f Makefile.sync apply-patches`. Repeat until all patches are successfully applied.
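A sketch of the conflict-resolution loop described above (file names are placeholders; the tracking repo is assumed to live in `./vendor/`):

```shell
cd vendor
# fix the conflicts reported by git am, then stage the resolved files
git add <resolved files>
git am --continue
cd ..
# re-run the patch series until it applies cleanly
make -f Makefile.sync apply-patches
```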
Once all patches are applied, commit the changes to the tracking repository.

```shell
make -f Makefile.sync format-patches sync
```

Build and test Ollama, and make any necessary changes to the Go code based on the new base commit. Submit your PR to the Ollama repo.

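Building and testing against the new base commit is the usual Go workflow; a minimal pass might look like:

```shell
go build .
go test ./...
```
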
### Generating Patches

When working on new fixes or features that impact vendored code, use the following model. First get a clean tracking repo with all current patches applied:

```shell
make -f Makefile.sync clean apply-patches
```

Now edit the upstream native code in the `./vendor/` directory. You do not need to commit every change in order to build; a dirty working tree in the tracking repo is OK while developing. Simply save in your editor, then run the following to refresh the vendored code with your changes, build the backend(s), and build ollama:

```shell
make sync
make -j 8
go build .
```

> [!IMPORTANT]
> Do **NOT** run `apply-patches` while you're iterating, as that will reset the tracking repo. It will detect a dirty tree and abort, but if your tree is clean and you accidentally ran this target, use `git reflog` to recover your commit(s).

Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with:

```shell
make -f Makefile.sync format-patches
```

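For example (commit message and file list are illustrative), commit your work in the tracking repo and regenerate the patch files that Ollama carries:

```shell
cd vendor
git add path/to/changed/files
git commit -m "describe the upstream-facing change"
cd ..
make -f Makefile.sync format-patches
git status   # the regenerated patch files should show up here
```
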
> [!IMPORTANT]
> Once you have completed this step, it is safe to run `apply-patches` since your change is preserved in the patches.

In your `./vendor/` directory, create a branch, cherry-pick the new commit to that branch, and then submit a PR upstream to llama.cpp.

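A sketch of that upstream flow, assuming the tracking repo's `origin` points at upstream llama.cpp and you have a fork added as a remote named `myfork` (remote, branch, and commit names are illustrative):

```shell
cd vendor
git checkout -b my-upstream-fix origin/master
git cherry-pick <your-commit-sha>
git push myfork my-upstream-fix
# then open the pull request against ggerganov/llama.cpp from your fork
```
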
**`llama/amx.h`** (vendored) — 34 lines removed (`@@ -1,34 +0,0 @@`). The file consisted of the standard vendored-file header ("llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file") and MIT license notice (Copyright (c) 2023-2024 The ggml authors), followed by includes of `ggml-backend.h` and `ggml-cpu-impl.h` and a declaration of `ggml_backend_amx_buffer_type(void)` guarded by `#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)`.