Compare commits

Comparing brucemacd/... with main — 196 commits

Commit SHAs:

0aa8b371dd 23125648b8 0478d440f0 8cc33f4c2b f46df4e5d2 c6bcdc4223 4b903f088a
c7f4ae7b9c 526b2ed102 a7240c6d63 9d6df90805 0cefd46f23 ad035ad595 f95a1f2bef
82a9e9462a 76724e2f29 ecf14a220f 69ce44b33c 5969674cf1 867d75b21e 3fa78598a1
0d6e35d3c6 6e9a7a2568 b585a58121 fa9973cd7f 3d9498a425 3098c8b29b 5e380c3b42
392de84031 af31ccefc0 fa393554b9 307e3b3e1d 4090aca97b 92ce438de0 424810450f
95e744beeb 3b2d2c8326 d931ee8f22 7073600797 b1c40138da 17466217e5 1703d1472e
913905028b 7e5c8eee5c 6a74bba7e7 76ea735aaf dd1d4e99e7 a6ef73f4f2 c2f5d6662b
57fb759f3c 8dd12c873d e6d2d04121 074bac8447 8e8f2c6d67 938e8447e8 d5d5f0c445
a7835c6716 ad3c7c9bda 415c8fcc3d 718eda1b3e 421b7edeb4 7b68e254c2 7bec2724a5
a27462b708 6bf0b8193a db428adbb8 fe5b9bb21b 6ec71d8fb6 44b466eeb2 a25f3f8260
dd93e1af85 5cfc1c39f3 f0ad49ea17 7ba9fa9c7d 8bf11b84c1 470af8ab89 178761aef3
f0c66e6dea 54055a6dae 340448d2d1 ced7d0e53d a0dba0f8ae 5e20b170a7 d26c18e25c
8d376acc9b dc1e81f027 5d0279164c 214a7678ea 4892872c18 0b9198bf47 e9e5f61c45
11dde41824 a53d744b01 40b10eee6d 424f648632 2eb1fb3231 0806521642 88738b357b
4e535e6188 40b8fdbdca 1d99451ad7 09bb2e30f6 dc264be6ff fbe7039618 943464ccb8
369de832cd 3457a315b2 ed4e139314 56dc316a57 2fec73eef6 1e7f62cb42 ccb7eb8135
637fd21230 0fe487e732 6bfaa6e282 378d3210dc 97fe45e36d 64a9cc8f05 f50d691254
34c3b68fc8 f33ccd5d27 bc108b9ad6 ef65174df2 42ecb9f138 5c0331fd83 e7019c9455
d98bfe7e70 6747099d71 ccc8c6777b dbb149e6f7 a807985e59 8643c4d5bf b0c3aba590
19c0c25de8 2f723ac2d6 249fbbe52f c38680b8a1 16fca86c4a 0f3f9e353d 6bd0a983cd
1861fbdeb5 3b96a93672 e53b3cbd0c b51e0f397c b42970063d 493385eb3e 9876c9faa4
4e415029b3 e172f095ba c001b98087 23fc8e92eb 4059a297a6 66b2539238 ef27d52e79
b2a465296d 5d097277ef 071a9872cb 0bd0454ea7 01aa788722 ead27aa9fe b816ff86c9
e5d84fb90b dd66712e31 f66216e399 f4f0992b6e 1feff61977 5e0b904e88 131f0355a5
ce929984a3 4b34930a31 74bd09652d fb6252d786 c794fef2f2 00ebda8cc4 d14ce75b95
2d6eac9084 3ed7ad3ab3 6d1103048e 0ff28758b3 d3e9ca3eda 0fbfcf3c9c 0c220935bd
ffbfe833da 42a14f7f63 f8c3dbe5b5 b078dd157c 2ddacd7516 da0e345200 df94175a0f
61a8825216 021dcf089d bf24498b1e 95e271d98f 364629b8d6 108fe02165 4561fff36e
.github/workflows/release.yaml — 16 changes (vendored)

@@ -432,6 +432,22 @@ jobs:
          docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
        working-directory: ${{ runner.temp }}

  # Trigger downstream release process
  trigger:
    runs-on: ubuntu-latest
    environment: release
    needs: [darwin-build, windows-build, windows-depends]
    steps:
      - name: Trigger downstream release process
        run: |
          curl -L \
            -X POST \
            -H "Accept: application/vnd.github+json" \
            -H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
            -H "X-GitHub-Api-Version: 2022-11-28" \
            https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
            -d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"

  # Aggregate all the assets and ship a release
  release:
    needs: [darwin-sign, windows-sign, linux-build]
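The new `trigger` job fires a `repository_dispatch` event in a downstream repository so the external release process can pick up the run. As a rough illustration only (not part of the workflow), the same request the `curl` step issues could be made from Go like this; the environment variable names mirror the secret, variable, and payload fields shown above:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
)

func main() {
	// Same payload the workflow step sends with curl.
	payload := map[string]any{
		"event_type": "trigger-workflow",
		"client_payload": map[string]string{
			"run_id":  os.Getenv("GITHUB_RUN_ID"),
			"version": os.Getenv("VERSION"), // e.g. GITHUB_REF_NAME with the leading "v" stripped
		},
	}
	body, _ := json.Marshal(payload)

	req, err := http.NewRequest(http.MethodPost,
		fmt.Sprintf("https://api.github.com/repos/ollama/%s/dispatches", os.Getenv("RELEASE_REPO")),
		bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Accept", "application/vnd.github+json")
	req.Header.Set("Authorization", "Bearer "+os.Getenv("RELEASE_TOKEN"))
	req.Header.Set("X-GitHub-Api-Version", "2022-11-28")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status) // GitHub responds 204 No Content on success
}
```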
.github/workflows/test.yaml — 4 changes (vendored)

@@ -237,5 +237,5 @@ jobs:
      - uses: actions/checkout@v4
      - name: Verify patches apply cleanly and do not change files
        run: |
          make -f Makefile.sync clean sync
          git diff --compact-summary --exit-code
          make -f Makefile.sync clean checkout apply-patches sync
          git diff --compact-summary --exit-code
@@ -19,8 +19,8 @@ linters:
    - nolintlint
    - nosprintfhostport
    - staticcheck
    - tenv
    - unconvert
    - usetesting
    - wastedassign
    - whitespace
  disable:
@@ -24,6 +24,7 @@ set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)
set(GGML_CUDA_COMPRESSION_MODE default)

if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
    OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))

@@ -86,9 +87,9 @@ if(CMAKE_CUDA_COMPILER)
  )
endif()

set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
  CACHE STRING
  "Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
  "Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
)

check_language(HIP)

@@ -97,7 +98,7 @@ if(CMAKE_HIP_COMPILER)

  find_package(hip REQUIRED)
  if(NOT AMDGPU_TARGETS)
    list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
    list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
  elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
    list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
  endif()
@@ -21,14 +21,16 @@
      "name": "CUDA 11",
      "inherits": [ "CUDA" ],
      "cacheVariables": {
        "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86"
        "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
        "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
      }
    },
    {
      "name": "CUDA 12",
      "inherits": [ "CUDA" ],
      "cacheVariables": {
        "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120"
        "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
        "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
      }
    },
    {

@@ -56,7 +58,7 @@
      "name": "ROCm 6",
      "inherits": [ "ROCm" ],
      "cacheVariables": {
        "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
        "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
      }
    }
  ],
@@ -51,7 +51,7 @@ see if the change were accepted.

The title should look like:

    <package>: <short description>
    <package>: <short description>

The package is the most affected Go package. If the change does not affect Go
code, then use the directory name instead. Changes to a single well-known
@@ -104,8 +104,8 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6

FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618

.PHONY: help
help:

@@ -15,27 +15,30 @@ help:
	@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"

.PHONY: sync
sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml apply-patches
sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal

.PHONY: llama/build-info.cpp
llama/build-info.cpp: llama/build-info.cpp.in
	sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
llama/build-info.cpp: llama/build-info.cpp.in llama/llama.cpp
	sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' <$< >$@

ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
	go generate ./$(@D)

.PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/ apply-patches
llama/llama.cpp: llama/vendor/
	rsync -arvzc -f "merge $@/.rsync-filter" $< $@

.PHONY: ml/backend/ggml/ggml apply-patches
ml/backend/ggml/ggml: llama/vendor/ggml/ apply-patches
.PHONY: ml/backend/ggml/ggml
ml/backend/ggml/ggml: llama/vendor/ggml/
	rsync -arvzc -f "merge $@/.rsync-filter" $< $@

PATCHES=$(wildcard llama/patches/*.patch)
PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))

.PHONY: apply-patches
.NOTPARALLEL:
apply-patches: $(addsuffix ed, $(PATCHES))
apply-patches: $(PATCHED)

%.patched: %.patch
llama/patches/.%.patched: llama/patches/%.patch
	@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi

.PHONY: checkout

@@ -57,4 +60,4 @@ format-patches: llama/patches

.PHONE: clean
clean: checkout
	$(RM) $(addsuffix ed, $(PATCHES))
	$(RM) llama/patches/.*.patched
README.md — 53 changes

@@ -61,6 +61,8 @@ Here are some example models that can be downloaded:
| QwQ | 32B | 20GB | `ollama run qwq` |
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
| Llama 4 | 109B | 67GB | `ollama run llama4:scout` |
| Llama 4 | 400B | 245GB | `ollama run llama4:maverick` |
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
| Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` |
| Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` |

@@ -77,7 +79,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |

> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

@@ -285,12 +287,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [big-AGI](https://github.com/enricoros/big-AGI)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)

@@ -311,6 +314,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
- [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira Integration to generate issues, tasks, epics)
- [ojira](https://github.com/AliAhmedNada/ojira) (Jira chrome plugin to easily generate descriptions for tasks)
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)

@@ -324,13 +329,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support, and multiple large language models.)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord )
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in Discord)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy-to-use GUI with sample custom LLM for Drivers Education)
- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)

@@ -339,16 +345,16 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
- [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
- [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita)
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
- [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
- [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)

@@ -366,7 +372,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard, and said in the meetings)
- [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
- [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)

@@ -384,7 +390,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)

@@ -392,8 +398,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)

### Cloud

@@ -433,7 +444,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)

### Apple Vision Pro

@@ -460,7 +474,7 @@ See the [API documentation](./docs/api.md) for all endpoints.

### Libraries

- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
- [LangChain](https://python.langchain.com/docs/integrations/chat/ollama/) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
- [crewAI](https://github.com/crewAIInc/crewAI)
- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)

@@ -507,19 +521,21 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)

### Mobile

- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS, and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)

@@ -543,7 +559,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)

@@ -553,8 +569,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
- [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depend on ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)

@@ -568,6 +584,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)

### Supported backends
@@ -1,7 +1,6 @@
package api

import (
    "context"
    "encoding/json"
    "fmt"
    "net/http"

@@ -137,7 +136,7 @@ func TestClientStream(t *testing.T) {
    client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)

    var receivedChunks []ChatResponse
    err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
    err := client.stream(t.Context(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
        var resp ChatResponse
        if err := json.Unmarshal(chunk, &resp); err != nil {
            return fmt.Errorf("failed to unmarshal chunk: %w", err)

@@ -223,7 +222,7 @@ func TestClientDo(t *testing.T) {
        ID string `json:"id"`
        Success bool `json:"success"`
    }
    err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
    err := client.do(t.Context(), http.MethodPost, "/v1/messages", nil, &resp)

    if tc.wantErr != "" {
        if err == nil {
api/types.go — 120 changes

@@ -12,6 +12,7 @@ import (
    "time"

    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/types/model"
)

// StatusError is an error with an HTTP status code and message.

@@ -75,13 +76,13 @@ type GenerateRequest struct {
    // this request.
    KeepAlive *Duration `json:"keep_alive,omitempty"`

    // Images is an optional list of base64-encoded images accompanying this
    // Images is an optional list of raw image bytes accompanying this
    // request, for multimodal models.
    Images []ImageData `json:"images,omitempty"`

    // Options lists model-specific options. For example, temperature can be
    // set through this field, if the model supports it.
    Options map[string]interface{} `json:"options"`
    Options map[string]any `json:"options"`
}

// ChatRequest describes a request sent by [Client.Chat].

@@ -106,7 +107,7 @@ type ChatRequest struct {
    Tools `json:"tools,omitempty"`

    // Options lists model-specific options.
    Options map[string]interface{} `json:"options"`
    Options map[string]any `json:"options"`
}

type Tools []Tool

@@ -162,19 +163,65 @@ func (t *ToolCallFunctionArguments) String() string {

type Tool struct {
    Type     string       `json:"type"`
    Items    any          `json:"items,omitempty"`
    Function ToolFunction `json:"function"`
}

// PropertyType can be either a string or an array of strings
type PropertyType []string

// UnmarshalJSON implements the json.Unmarshaler interface
func (pt *PropertyType) UnmarshalJSON(data []byte) error {
    // Try to unmarshal as a string first
    var s string
    if err := json.Unmarshal(data, &s); err == nil {
        *pt = []string{s}
        return nil
    }

    // If that fails, try to unmarshal as an array of strings
    var a []string
    if err := json.Unmarshal(data, &a); err != nil {
        return err
    }
    *pt = a
    return nil
}

// MarshalJSON implements the json.Marshaler interface
func (pt PropertyType) MarshalJSON() ([]byte, error) {
    if len(pt) == 1 {
        // If there's only one type, marshal as a string
        return json.Marshal(pt[0])
    }
    // Otherwise marshal as an array
    return json.Marshal([]string(pt))
}

// String returns a string representation of the PropertyType
func (pt PropertyType) String() string {
    if len(pt) == 0 {
        return ""
    }
    if len(pt) == 1 {
        return pt[0]
    }
    return fmt.Sprintf("%v", []string(pt))
}

type ToolFunction struct {
    Name        string `json:"name"`
    Description string `json:"description"`
    Parameters  struct {
        Type       string   `json:"type"`
        Defs       any      `json:"$defs,omitempty"`
        Items      any      `json:"items,omitempty"`
        Required   []string `json:"required"`
        Properties map[string]struct {
            Type        string   `json:"type"`
            Description string   `json:"description"`
            Enum        []string `json:"enum,omitempty"`
            Type        PropertyType `json:"type"`
            Items       any          `json:"items,omitempty"`
            Description string       `json:"description"`
            Enum        []any        `json:"enum,omitempty"`
        } `json:"properties"`
    } `json:"parameters"`
}
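To make the effect of the new `PropertyType` concrete, here is a small usage sketch (illustrative only, not part of the change): a tool property whose JSON `type` is either a bare string or an array of strings decodes into the same Go value, and a single-element value encodes back to a bare string.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	var single, multi api.PropertyType

	// A bare string and an array of strings both decode into the same slice type.
	_ = json.Unmarshal([]byte(`"string"`), &single)          // PropertyType{"string"}
	_ = json.Unmarshal([]byte(`["string", "null"]`), &multi) // PropertyType{"string", "null"}

	// One element encodes as a plain string; multiple elements encode as an array.
	a, _ := json.Marshal(single) // "string"
	b, _ := json.Marshal(multi)  // ["string","null"]
	fmt.Println(string(a), string(b))
}
```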
@@ -224,9 +271,6 @@ type Options struct {
    RepeatPenalty    float32  `json:"repeat_penalty,omitempty"`
    PresencePenalty  float32  `json:"presence_penalty,omitempty"`
    FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
    Mirostat         int      `json:"mirostat,omitempty"`
    MirostatTau      float32  `json:"mirostat_tau,omitempty"`
    MirostatEta      float32  `json:"mirostat_eta,omitempty"`
    Stop             []string `json:"stop,omitempty"`
}

@@ -236,12 +280,7 @@ type Runner struct {
    NumBatch  int   `json:"num_batch,omitempty"`
    NumGPU    int   `json:"num_gpu,omitempty"`
    MainGPU   int   `json:"main_gpu,omitempty"`
    LowVRAM   bool  `json:"low_vram,omitempty"`
    F16KV     bool  `json:"f16_kv,omitempty"` // Deprecated: This option is ignored
    LogitsAll bool  `json:"logits_all,omitempty"`
    VocabOnly bool  `json:"vocab_only,omitempty"`
    UseMMap   *bool `json:"use_mmap,omitempty"`
    UseMLock  bool  `json:"use_mlock,omitempty"`
    NumThread int   `json:"num_thread,omitempty"`
}

@@ -260,7 +299,7 @@ type EmbedRequest struct {
    Truncate *bool `json:"truncate,omitempty"`

    // Options lists model-specific options.
    Options map[string]interface{} `json:"options"`
    Options map[string]any `json:"options"`
}

// EmbedResponse is the response from [Client.Embed].

@@ -286,7 +325,7 @@ type EmbeddingRequest struct {
    KeepAlive *Duration `json:"keep_alive,omitempty"`

    // Options lists model-specific options.
    Options map[string]interface{} `json:"options"`
    Options map[string]any `json:"options"`
}

// EmbeddingResponse is the response from [Client.Embeddings].

@@ -332,7 +371,7 @@ type ShowRequest struct {
    Template string `json:"template"`
    Verbose  bool   `json:"verbose"`

    Options map[string]interface{} `json:"options"`
    Options map[string]any `json:"options"`

    // Deprecated: set the model name with Model instead
    Name string `json:"name"`

@@ -340,17 +379,18 @@

// ShowResponse is the response returned from [Client.Show].
type ShowResponse struct {
    License       string         `json:"license,omitempty"`
    Modelfile     string         `json:"modelfile,omitempty"`
    Parameters    string         `json:"parameters,omitempty"`
    Template      string         `json:"template,omitempty"`
    System        string         `json:"system,omitempty"`
    Details       ModelDetails   `json:"details,omitempty"`
    Messages      []Message      `json:"messages,omitempty"`
    ModelInfo     map[string]any `json:"model_info,omitempty"`
    ProjectorInfo map[string]any `json:"projector_info,omitempty"`
    Tensors       []Tensor       `json:"tensors,omitempty"`
    ModifiedAt    time.Time      `json:"modified_at,omitempty"`
    License       string             `json:"license,omitempty"`
    Modelfile     string             `json:"modelfile,omitempty"`
    Parameters    string             `json:"parameters,omitempty"`
    Template      string             `json:"template,omitempty"`
    System        string             `json:"system,omitempty"`
    Details       ModelDetails       `json:"details,omitempty"`
    Messages      []Message          `json:"messages,omitempty"`
    ModelInfo     map[string]any     `json:"model_info,omitempty"`
    ProjectorInfo map[string]any     `json:"projector_info,omitempty"`
    Tensors       []Tensor           `json:"tensors,omitempty"`
    Capabilities  []model.Capability `json:"capabilities,omitempty"`
    ModifiedAt    time.Time          `json:"modified_at,omitempty"`
}

// CopyRequest is the request passed to [Client.Copy].

@@ -423,13 +463,6 @@ type ProcessModelResponse struct {
    SizeVRAM int64 `json:"size_vram"`
}

type RetrieveModelResponse struct {
    Id      string `json:"id"`
    Object  string `json:"object"`
    Created int64  `json:"created"`
    OwnedBy string `json:"owned_by"`
}

type TokenResponse struct {
    Token string `json:"token"`
}

@@ -503,7 +536,7 @@ func (m *Metrics) Summary() {
    }
}

func (opts *Options) FromMap(m map[string]interface{}) error {
func (opts *Options) FromMap(m map[string]any) error {
    valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
    typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct

@@ -560,12 +593,12 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
    }
    field.SetString(val)
case reflect.Slice:
    // JSON unmarshals to []interface{}, not []string
    val, ok := val.([]interface{})
    // JSON unmarshals to []any, not []string
    val, ok := val.([]any)
    if !ok {
        return fmt.Errorf("option %q must be of type array", key)
    }
    // convert []interface{} to []string
    // convert []any to []string
    slice := make([]string, len(val))
    for i, item := range val {
        str, ok := item.(string)

@@ -612,9 +645,6 @@ func DefaultOptions() Options {
    RepeatPenalty:    1.1,
    PresencePenalty:  0.0,
    FrequencyPenalty: 0.0,
    Mirostat:         0,
    MirostatTau:      5.0,
    MirostatEta:      0.1,
    Seed:             -1,

    Runner: Runner{

@@ -623,8 +653,6 @@
    NumBatch:  512,
    NumGPU:    -1, // -1 here indicates that NumGPU should be set dynamically
    NumThread: 0,  // let the runtime decide
    LowVRAM:   false,
    UseMLock:  false,
    UseMMap:   nil,
    },
}

@@ -672,7 +700,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
}

// FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
func FormatParams(params map[string][]string) (map[string]any, error) {
    opts := Options{}
    valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
    typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct

@@ -686,7 +714,7 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
    }
}

out := make(map[string]interface{})
out := make(map[string]any)
// iterate params and set values based on json struct tags
for key, vals := range params {
    if opt, ok := jsonOpts[key]; !ok {
@@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {

    for _, test := range tests {
        t.Run(test.name, func(t *testing.T) {
            var oMap map[string]interface{}
            var oMap map[string]any
            err := json.Unmarshal([]byte(test.req), &oMap)
            require.NoError(t, err)
            opts := DefaultOptions()

@@ -231,3 +231,144 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
        }
    }
}

func TestToolFunction_UnmarshalJSON(t *testing.T) {
    tests := []struct {
        name    string
        input   string
        wantErr string
    }{
        {
            name: "valid enum with same types",
            input: `{
                "name": "test",
                "description": "test function",
                "parameters": {
                    "type": "object",
                    "required": ["test"],
                    "properties": {
                        "test": {
                            "type": "string",
                            "description": "test prop",
                            "enum": ["a", "b", "c"]
                        }
                    }
                }
            }`,
            wantErr: "",
        },
        {
            name: "empty enum array",
            input: `{
                "name": "test",
                "description": "test function",
                "parameters": {
                    "type": "object",
                    "required": ["test"],
                    "properties": {
                        "test": {
                            "type": "string",
                            "description": "test prop",
                            "enum": []
                        }
                    }
                }
            }`,
            wantErr: "",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            var tf ToolFunction
            err := json.Unmarshal([]byte(tt.input), &tf)

            if tt.wantErr != "" {
                require.Error(t, err)
                assert.Contains(t, err.Error(), tt.wantErr)
            } else {
                require.NoError(t, err)
            }
        })
    }
}

func TestPropertyType_UnmarshalJSON(t *testing.T) {
    tests := []struct {
        name     string
        input    string
        expected PropertyType
    }{
        {
            name:     "string type",
            input:    `"string"`,
            expected: PropertyType{"string"},
        },
        {
            name:     "array of types",
            input:    `["string", "number"]`,
            expected: PropertyType{"string", "number"},
        },
        {
            name:     "array with single type",
            input:    `["string"]`,
            expected: PropertyType{"string"},
        },
    }

    for _, test := range tests {
        t.Run(test.name, func(t *testing.T) {
            var pt PropertyType
            if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
                t.Errorf("Unexpected error: %v", err)
            }

            if len(pt) != len(test.expected) {
                t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
            }

            for i, v := range pt {
                if v != test.expected[i] {
                    t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
                }
            }
        })
    }
}

func TestPropertyType_MarshalJSON(t *testing.T) {
    tests := []struct {
        name     string
        input    PropertyType
        expected string
    }{
        {
            name:     "single type",
            input:    PropertyType{"string"},
            expected: `"string"`,
        },
        {
            name:     "multiple types",
            input:    PropertyType{"string", "number"},
            expected: `["string","number"]`,
        },
        {
            name:     "empty type",
            input:    PropertyType{},
            expected: `[]`,
        },
    }

    for _, test := range tests {
        t.Run(test.name, func(t *testing.T) {
            data, err := json.Marshal(test.input)
            if err != nil {
                t.Errorf("Unexpected error: %v", err)
            }

            if string(data) != test.expected {
                t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
            }
        })
    }
}
@@ -4,20 +4,14 @@ import (
    "fmt"
    "log/slog"
    "os"
    "path/filepath"
    "strconv"
    "strings"

    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/logutil"
)

func InitLogging() {
    level := slog.LevelInfo

    if envconfig.Debug() {
        level = slog.LevelDebug
    }

    var logFile *os.File
    var err error
    // Detect if we're a GUI app on windows, and if not, send logs to console

@@ -33,20 +27,8 @@
            return
        }
    }
    handler := slog.NewTextHandler(logFile, &slog.HandlerOptions{
        Level:     level,
        AddSource: true,
        ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
            if attr.Key == slog.SourceKey {
                source := attr.Value.Any().(*slog.Source)
                source.File = filepath.Base(source.File)
            }
            return attr
        },
    })

    slog.SetDefault(slog.New(handler))

    slog.SetDefault(logutil.NewLogger(logFile, envconfig.LogLevel()))
    slog.Info("ollama app started")
}
benchmark/server_benchmark_test.go — new file, 178 lines

@@ -0,0 +1,178 @@
package benchmark

import (
    "context"
    "flag"
    "fmt"
    "testing"
    "time"

    "github.com/ollama/ollama/api"
)

// Command line flags
var modelFlag string

func init() {
    flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
    flag.Lookup("m").DefValue = "model"
}

// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
    if modelFlag == "" {
        b.Fatal("Error: -m flag is required for benchmark tests")
    }
    return modelFlag
}

type TestCase struct {
    name      string
    prompt    string
    maxTokens int
}

// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
    start := time.Now()
    var ttft time.Duration
    var metrics api.Metrics

    err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
        if ttft == 0 && resp.Response != "" {
            ttft = time.Since(start)
        }
        if resp.Done {
            metrics = resp.Metrics
        }
        return nil
    })

    // Report custom metrics as part of the benchmark results
    b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
    b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")

    // Token throughput metrics
    promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
    genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
    b.ReportMetric(promptThroughput, "prompt_tok/s")
    b.ReportMetric(genThroughput, "gen_tok/s")

    // Token counts
    b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
    b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
    if err != nil {
        b.Fatal(err)
    }
}

// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
    client := setup(b)
    tests := []TestCase{
        {"short_prompt", "Write a long story", 100},
        {"medium_prompt", "Write a detailed economic analysis", 500},
        {"long_prompt", "Write a comprehensive AI research paper", 1000},
    }
    m := modelName(b)

    for _, tt := range tests {
        b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
            ctx := b.Context()

            // Set number of tokens as our throughput metric
            b.SetBytes(int64(tt.maxTokens))

            for b.Loop() {
                b.StopTimer()
                // Ensure model is unloaded before each iteration
                unload(client, m, b)
                b.StartTimer()

                req := &api.GenerateRequest{
                    Model:   m,
                    Prompt:  tt.prompt,
                    Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
                }

                runGenerateBenchmark(b, ctx, client, req)
            }
        })
    }
}

// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
    client := setup(b)
    tests := []TestCase{
        {"short_prompt", "Write a long story", 100},
        {"medium_prompt", "Write a detailed economic analysis", 500},
        {"long_prompt", "Write a comprehensive AI research paper", 1000},
    }
    m := modelName(b)

    for _, tt := range tests {
        b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
            ctx := b.Context()

            // Pre-warm the model
            warmup(client, m, tt.prompt, b)

            // Set number of tokens as our throughput metric
            b.SetBytes(int64(tt.maxTokens))

            for b.Loop() {
                req := &api.GenerateRequest{
                    Model:   m,
                    Prompt:  tt.prompt,
                    Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
                }

                runGenerateBenchmark(b, ctx, client, req)
            }
        })
    }
}

// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        b.Fatal(err)
    }
    if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
        b.Fatalf("Model unavailable: %v", err)
    }

    return client
}

// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
    for range 3 {
        err := client.Generate(
            context.Background(),
            &api.GenerateRequest{
                Model:   model,
                Prompt:  prompt,
                Options: map[string]any{"num_predict": 50, "temperature": 0.1},
            },
            func(api.GenerateResponse) error { return nil },
        )
        if err != nil {
            b.Logf("Error during model warm-up: %v", err)
        }
    }
}

// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
    req := &api.GenerateRequest{
        Model:     model,
        KeepAlive: &api.Duration{Duration: 0},
    }
    if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
        b.Logf("Unload error: %v", err)
    }
    time.Sleep(1 * time.Second)
}
107
cmd/cmd.go
107
cmd/cmd.go
@ -18,6 +18,7 @@ import (
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -30,6 +31,7 @@ import (
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@ -40,6 +42,7 @@ import (
|
||||
"github.com/ollama/ollama/runner"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@ -105,7 +108,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
spinner.Stop()
|
||||
|
||||
req.Name = args[0]
|
||||
req.Model = args[0]
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
if quantize != "" {
|
||||
req.Quantize = quantize
|
||||
@ -116,34 +119,54 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(req.Files) > 0 {
|
||||
fileMap := map[string]string{}
|
||||
for f, digest := range req.Files {
|
||||
var g errgroup.Group
|
||||
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
||||
|
||||
files := syncmap.NewSyncMap[string, string]()
|
||||
for f, digest := range req.Files {
|
||||
g.Go(func() error {
|
||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
||||
return err
|
||||
}
|
||||
fileMap[filepath.Base(f)] = digest
|
||||
}
|
||||
req.Files = fileMap
|
||||
|
||||
// TODO: this is incorrect since the file might be in a subdirectory
|
||||
// instead this should take the path relative to the model directory
|
||||
// but the current implementation does not allow this
|
||||
files.Store(filepath.Base(f), digest)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if len(req.Adapters) > 0 {
|
||||
fileMap := map[string]string{}
|
||||
for f, digest := range req.Adapters {
|
||||
adapters := syncmap.NewSyncMap[string, string]()
|
||||
for f, digest := range req.Adapters {
|
||||
g.Go(func() error {
|
||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
||||
return err
|
||||
}
|
||||
fileMap[filepath.Base(f)] = digest
|
||||
}
|
||||
req.Adapters = fileMap
|
||||
|
||||
// TODO: same here
|
||||
adapters.Store(filepath.Base(f), digest)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Files = files.Items()
|
||||
req.Adapters = adapters.Items()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
|
||||
msg := resp.Status
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
|
||||
}
|
||||
bar = progress.NewBar(msg, resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
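The hunk above replaces the sequential blob-upload loops with a bounded worker pool: one errgroup.Group limited to GOMAXPROCS-1 goroutines, collecting results into concurrency-safe maps that are read back with Items() after Wait(). A standalone sketch of the same pattern, using a mutex-guarded map and a hash as a stand-in for the internal syncmap package and createBlob (file names and hashing are illustrative only):

package main

import (
    "crypto/sha256"
    "fmt"
    "path/filepath"
    "runtime"
    "sync"

    "golang.org/x/sync/errgroup"
)

// uploadAll processes each file concurrently with a bounded number of
// workers, mirroring the errgroup pattern used by CreateHandler.
func uploadAll(paths []string) (map[string]string, error) {
    var g errgroup.Group
    g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))

    var mu sync.Mutex
    digests := make(map[string]string)

    for _, p := range paths {
        g.Go(func() error {
            // stand-in for reading the file and uploading the blob
            sum := sha256.Sum256([]byte(p))
            mu.Lock()
            digests[filepath.Base(p)] = fmt.Sprintf("sha256:%x", sum)
            mu.Unlock()
            return nil
        })
    }
    if err := g.Wait(); err != nil {
        return nil, err
    }
    return digests, nil
}

func main() {
    out, err := uploadAll([]string{"model-00001-of-00002.safetensors", "tokenizer.json"})
    fmt.Println(out, err)
}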
@ -212,7 +235,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
|
||||
}
|
||||
}()
|
||||
|
||||
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return digest, nil
|
||||
@ -267,7 +290,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]interface{}{},
|
||||
Options: map[string]any{},
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
@ -339,6 +362,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
// these are left in for backwards compatibility with older servers
|
||||
// that don't have the capabilities field in the model info
|
||||
if len(info.ProjectorInfo) != 0 {
|
||||
opts.MultiModal = true
|
||||
}
|
||||
@ -669,6 +697,15 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
return
|
||||
})
|
||||
|
||||
if len(resp.Capabilities) > 0 {
|
||||
tableRender("Capabilities", func() (rows [][]string) {
|
||||
for _, capability := range resp.Capabilities {
|
||||
rows = append(rows, []string{"", capability.String()})
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
if resp.ProjectorInfo != nil {
|
||||
tableRender("Projector", func() (rows [][]string) {
|
||||
arch := resp.ProjectorInfo["general.architecture"].(string)
|
||||
@ -703,6 +740,8 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
for _, k := range keys {
|
||||
var v string
|
||||
switch vData := resp.ModelInfo[k].(type) {
|
||||
case bool:
|
||||
v = fmt.Sprintf("%t", vData)
|
||||
case string:
|
||||
v = vData
|
||||
case float64:
|
||||
@ -791,13 +830,38 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
// This is the initial status update for the
|
||||
// layer, which the server sends before
|
||||
// beginning the download, for clients to
|
||||
// compute total size and prepare for
|
||||
// downloads, if needed.
|
||||
//
|
||||
// Skipping this here to avoid showing a 0%
|
||||
// progress bar, which *should* clue the user
|
||||
// into the fact that many things are being
|
||||
// downloaded and that the current active
|
||||
// download is not the last. However, in rare
|
||||
// cases it seems to be triggering to some, and
|
||||
// it isn't worth explaining, so just ignore
|
||||
// and regress to the old UI that keeps giving
|
||||
// you the "But wait, there is more!" after
|
||||
// each "100% done" bar, which is "better."
|
||||
return nil
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
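The new progress-bar naming above derives a short layer label from the digest with strings.CutPrefix, trimming to at most 12 hex characters. A small standalone check of that logic (the sample digest is made up):

package main

import (
    "fmt"
    "strings"
)

// shortLayerName mirrors the digest-trimming logic in PullHandler's callback.
func shortLayerName(digest string) string {
    name, isDigest := strings.CutPrefix(digest, "sha256:")
    name = strings.TrimSpace(name)
    if isDigest {
        name = name[:min(12, len(name))]
    }
    return name
}

func main() {
    // made-up digest for illustration
    fmt.Println(shortLayerName("sha256:0aa8b371dd1234567890abcdef")) // 0aa8b371dd12
    fmt.Println(shortLayerName("manifest"))                          // manifest
}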
@ -817,11 +881,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: args[0], Insecure: insecure}
|
||||
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return client.Pull(cmd.Context(), &request, fn)
|
||||
}
|
||||
|
||||
type generateContextKey string
|
||||
@ -835,7 +895,7 @@ type runOptions struct {
|
||||
Format string
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]interface{}
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
}
|
||||
@ -1364,7 +1424,6 @@ func NewCLI() *cobra.Command {
|
||||
envVars["OLLAMA_NOPRUNE"],
|
||||
envVars["OLLAMA_ORIGINS"],
|
||||
envVars["OLLAMA_SCHED_SPREAD"],
|
||||
envVars["OLLAMA_TMPDIR"],
|
||||
envVars["OLLAMA_FLASH_ATTENTION"],
|
||||
envVars["OLLAMA_KV_CACHE_TYPE"],
|
||||
envVars["OLLAMA_LLM_LIBRARY"],
|
||||
|
@ -2,7 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -16,6 +15,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
@ -87,6 +87,8 @@ func TestShowInfo(t *testing.T) {
|
||||
ModelInfo: map[string]any{
|
||||
"general.architecture": "test",
|
||||
"general.parameter_count": float64(8_000_000_000),
|
||||
"some.true_bool": true,
|
||||
"some.false_bool": false,
|
||||
"test.context_length": float64(1000),
|
||||
"test.embedding_length": float64(11434),
|
||||
},
|
||||
@ -111,6 +113,8 @@ func TestShowInfo(t *testing.T) {
|
||||
Metadata
|
||||
general.architecture test
|
||||
general.parameter_count 8e+09
|
||||
some.false_bool false
|
||||
some.true_bool true
|
||||
test.context_length 1000
|
||||
test.embedding_length 11434
|
||||
|
||||
@ -256,6 +260,34 @@ Weigh anchor!
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("capabilities", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
if err := showInfo(&api.ShowResponse{
|
||||
Details: api.ModelDetails{
|
||||
Family: "test",
|
||||
ParameterSize: "7B",
|
||||
QuantizationLevel: "FP16",
|
||||
},
|
||||
Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
|
||||
}, false, &b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := " Model\n" +
|
||||
" architecture test \n" +
|
||||
" parameters 7B \n" +
|
||||
" quantization FP16 \n" +
|
||||
"\n" +
|
||||
" Capabilities\n" +
|
||||
" vision \n" +
|
||||
" tools \n" +
|
||||
"\n"
|
||||
|
||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteHandler(t *testing.T) {
|
||||
@ -304,7 +336,7 @@ func TestDeleteHandler(t *testing.T) {
|
||||
t.Cleanup(mockServer.Close)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.TODO())
|
||||
cmd.SetContext(t.Context())
|
||||
if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
|
||||
t.Fatalf("DeleteHandler failed: %v", err)
|
||||
}
|
||||
@ -366,11 +398,6 @@ func TestGetModelfileName(t *testing.T) {
|
||||
var expectedFilename string
|
||||
|
||||
if tt.fileExists {
|
||||
tempDir, err := os.MkdirTemp("", "modelfiledir")
|
||||
defer os.RemoveAll(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("temp modelfile dir creation failed: %v", err)
|
||||
}
|
||||
var fn string
|
||||
if tt.modelfileName != "" {
|
||||
fn = tt.modelfileName
|
||||
@ -378,10 +405,11 @@ func TestGetModelfileName(t *testing.T) {
|
||||
fn = "Modelfile"
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(tempDir, fn)
|
||||
tempFile, err := os.CreateTemp(t.TempDir(), fn)
|
||||
if err != nil {
|
||||
t.Fatalf("temp modelfile creation failed: %v", err)
|
||||
}
|
||||
defer tempFile.Close()
|
||||
|
||||
expectedFilename = tempFile.Name()
|
||||
err = cmd.Flags().Set("file", expectedFilename)
|
||||
@ -496,7 +524,7 @@ func TestPushHandler(t *testing.T) {
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.SetContext(context.TODO())
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
// Redirect stderr to capture progress output
|
||||
oldStderr := os.Stderr
|
||||
@ -601,7 +629,7 @@ func TestListHandler(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetContext(context.TODO())
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
// Capture stdout
|
||||
oldStdout := os.Stdout
|
||||
@ -656,7 +684,7 @@ func TestCreateHandler(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name != "test-model" {
|
||||
if req.Model != "test-model" {
|
||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
||||
}
|
||||
|
||||
@ -696,7 +724,7 @@ func TestCreateHandler(t *testing.T) {
|
||||
}))
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
tempFile, err := os.CreateTemp("", "modelfile")
|
||||
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -716,7 +744,7 @@ func TestCreateHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.SetContext(context.TODO())
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
// Redirect stderr to capture progress output
|
||||
oldStderr := os.Stderr
|
||||
|
@ -44,7 +44,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
@ -503,6 +503,7 @@ func normalizeFilePath(fp string) string {
|
||||
"\\\\", "\\", // Escaped backslash
|
||||
"\\*", "*", // Escaped asterisk
|
||||
"\\?", "?", // Escaped question mark
|
||||
"\\~", "~", // Escaped tilde
|
||||
).Replace(fp)
|
||||
}
|
||||
|
||||
@ -510,7 +511,7 @@ func extractFileNames(input string) []string {
|
||||
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
return re.FindAllString(input, -1)
|
||||
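The widened regex above now also accepts .webp paths. A quick standalone check of the pattern as it appears in the diff (the sample input is illustrative):

package main

import (
    "fmt"
    "regexp"
)

func main() {
    // same pattern as extractFileNames, now including webp
    re := regexp.MustCompile(`(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`)

    input := `look at ./photos/cat.webp and /tmp/dog.PNG but not notes.txt`
    fmt.Println(re.FindAllString(input, -1))
    // prints: [./photos/cat.webp /tmp/dog.PNG]
}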
@ -530,6 +531,8 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
|
||||
input = strings.ReplaceAll(input, "'"+fp+"'", "")
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
@ -550,7 +553,7 @@ func getImageData(filePath string) ([]byte, error) {
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(buf)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
|
||||
if !slices.Contains(allowedTypes, contentType) {
|
||||
return nil, fmt.Errorf("invalid image type: %s", contentType)
|
||||
}
|
||||
|
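getImageData above now accepts WebP as well; the type check is a sniff of the file header via http.DetectContentType. A hedged standalone sketch of the same check (the helper name and path are placeholders):

package main

import (
    "fmt"
    "net/http"
    "os"
    "slices"
)

// isSupportedImage reports whether the file looks like a JPEG, PNG, or WebP,
// mirroring the content-type check in getImageData.
func isSupportedImage(path string) (bool, error) {
    f, err := os.Open(path)
    if err != nil {
        return false, err
    }
    defer f.Close()

    // DetectContentType considers at most the first 512 bytes
    buf := make([]byte, 512)
    n, err := f.Read(buf)
    if err != nil {
        return false, err
    }

    contentType := http.DetectContentType(buf[:n])
    allowed := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
    return slices.Contains(allowed, contentType), nil
}

func main() {
    ok, err := isSupportedImage("example.webp") // placeholder path
    fmt.Println(ok, err)
}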
@ -1,6 +1,8 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -10,14 +12,17 @@ func TestExtractFilenames(t *testing.T) {
|
||||
// Unix style paths
|
||||
input := ` some preamble
|
||||
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
|
||||
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
|
||||
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG
|
||||
/unescaped space /six.webp inbetween6 /valid\ path/dir/seven.WEBP`
|
||||
res := extractFileNames(input)
|
||||
assert.Len(t, res, 5)
|
||||
assert.Len(t, res, 7)
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[1], "two.jpg")
|
||||
assert.Contains(t, res[2], "three.jpeg")
|
||||
assert.Contains(t, res[3], "four.png")
|
||||
assert.Contains(t, res[4], "five.JPG")
|
||||
assert.Contains(t, res[5], "six.webp")
|
||||
assert.Contains(t, res[6], "seven.WEBP")
|
||||
assert.NotContains(t, res[4], '"')
|
||||
assert.NotContains(t, res, "inbetween1")
|
||||
assert.NotContains(t, res, "./1.svg")
|
||||
@ -28,10 +33,12 @@ func TestExtractFilenames(t *testing.T) {
|
||||
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
|
||||
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
|
||||
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
|
||||
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
|
||||
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG
|
||||
c:/users/jdoe/eleven.webp inbetween11 c:/program files/someplace/twelve.WebP inbetween12
|
||||
d:\path with\spaces\thirteen.WEBP some ending
|
||||
`
|
||||
res = extractFileNames(input)
|
||||
assert.Len(t, res, 10)
|
||||
assert.Len(t, res, 13)
|
||||
assert.NotContains(t, res, "inbetween2")
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[0], "c:")
|
||||
@ -49,4 +56,31 @@ d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
|
||||
assert.Contains(t, res[8], "d:")
|
||||
assert.Contains(t, res[9], "ten.PNG")
|
||||
assert.Contains(t, res[9], "E:")
|
||||
assert.Contains(t, res[10], "eleven.webp")
|
||||
assert.Contains(t, res[10], "c:")
|
||||
assert.Contains(t, res[11], "twelve.WebP")
|
||||
assert.Contains(t, res[11], "c:")
|
||||
assert.Contains(t, res[12], "thirteen.WEBP")
|
||||
assert.Contains(t, res[12], "d:")
|
||||
}
|
||||
|
||||
// Ensure that file paths wrapped in single quotes are removed with the quotes.
|
||||
func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
fp := filepath.Join(dir, "img.jpg")
|
||||
data := make([]byte, 600)
|
||||
copy(data, []byte{
|
||||
0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F',
|
||||
0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0xff, 0xd9,
|
||||
})
|
||||
if err := os.WriteFile(fp, data, 0o600); err != nil {
|
||||
t.Fatalf("failed to write test image: %v", err)
|
||||
}
|
||||
|
||||
input := "before '" + fp + "' after"
|
||||
cleaned, imgs, err := extractFileData(input)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, imgs, 1)
|
||||
assert.Equal(t, cleaned, "before after")
|
||||
}
|
||||
|
@ -1,25 +1,26 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type ModelParameters struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
TextModel TextParameters `json:"text_config"`
|
||||
}
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
|
||||
type TextParameters struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
TextModel struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
type AdapterParameters struct {
|
||||
@ -84,27 +85,17 @@ func (ModelParameters) specialTokenTypes() []string {
|
||||
}
|
||||
}
|
||||
|
||||
func (ModelParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
|
||||
return ggml.WriteGGUF(ws, kv, ts)
|
||||
}
|
||||
|
||||
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
|
||||
return ggml.WriteGGUF(ws, kv, ts)
|
||||
}
|
||||
|
||||
type ModelConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(*Tokenizer) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
||||
Tensors([]Tensor) []ggml.Tensor
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
||||
Replacements() []string
|
||||
|
||||
// specialTokenTypes returns any special token types the model uses
|
||||
specialTokenTypes() []string
|
||||
// writeFile writes the model to the provided io.WriteSeeker
|
||||
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
|
||||
}
|
||||
|
||||
type moreParser interface {
|
||||
@ -115,15 +106,13 @@ type AdapterConverter interface {
|
||||
// KV maps parameters to LLM key-values
|
||||
KV(ggml.KV) ggml.KV
|
||||
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
|
||||
Tensors([]Tensor) []ggml.Tensor
|
||||
Tensors([]Tensor) []*ggml.Tensor
|
||||
// Replacements returns a list of string pairs to replace in tensor names.
|
||||
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
||||
Replacements() []string
|
||||
|
||||
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
|
||||
}
|
||||
|
||||
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
|
||||
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
|
||||
bts, err := fs.ReadFile(fsys, "adapter_config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@ -158,14 +147,14 @@ func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
|
||||
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||
func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
||||
func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
bts, err := fs.ReadFile(fsys, "config.json")
|
||||
if err != nil {
|
||||
return err
|
||||
@ -182,8 +171,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
||||
|
||||
var conv ModelConverter
|
||||
switch p.Architectures[0] {
|
||||
case "LlamaForCausalLM", "MistralForCausalLM":
|
||||
case "LlamaForCausalLM":
|
||||
conv = &llamaModel{}
|
||||
case "MllamaForConditionalGeneration":
|
||||
conv = &mllamaModel{}
|
||||
case "Llama4ForConditionalGeneration":
|
||||
conv = &llama4Model{}
|
||||
case "Mistral3ForConditionalGeneration":
|
||||
conv = &mistral3Model{}
|
||||
case "MixtralForCausalLM":
|
||||
conv = &mixtralModel{}
|
||||
case "GemmaForCausalLM":
|
||||
@ -196,12 +191,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
conv = &qwen2Model{}
|
||||
case "Qwen2_5_VLForConditionalGeneration":
|
||||
conv = &qwen25VLModel{}
|
||||
case "BertModel":
|
||||
conv = &bertModel{}
|
||||
case "CohereForCausalLM":
|
||||
conv = &commandrModel{}
|
||||
default:
|
||||
return errors.New("unsupported architecture")
|
||||
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, conv); err != nil {
|
||||
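The entry points in this file now take an *os.File instead of an io.WriteSeeker (see the ConvertAdapter and ConvertModel signatures above), so the shared writeFile helper can hand the file straight to ggml.WriteGGUF. A hedged sketch of driving a conversion; the paths are placeholders and error handling is abbreviated:

package main

import (
    "log"
    "os"

    "github.com/ollama/ollama/convert"
)

func main() {
    // placeholder output path for illustration
    f, err := os.Create("model.gguf")
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()

    // os.DirFS exposes the safetensors checkpoint directory as an fs.FS
    if err := convert.ConvertModel(os.DirFS("./checkpoint"), f); err != nil {
        log.Fatal(err)
    }
}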
@ -219,24 +216,22 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
||||
return err
|
||||
}
|
||||
|
||||
vocabSize := int(p.VocabSize)
|
||||
if vocabSize == 0 {
|
||||
tVocabSize := int(p.TextModel.VocabSize)
|
||||
vocabSize = tVocabSize
|
||||
}
|
||||
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
|
||||
|
||||
switch {
|
||||
case vocabSize == 0:
|
||||
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
|
||||
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
|
||||
case vocabSize > len(t.Vocabulary.Tokens):
|
||||
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
||||
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
||||
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
||||
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
|
||||
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
|
||||
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
|
||||
}
|
||||
case vocabSize < len(t.Vocabulary.Tokens):
|
||||
return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
|
||||
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens))
|
||||
p.VocabSize = uint32(len(t.Vocabulary.Tokens))
|
||||
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
|
||||
default:
|
||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||
}
|
||||
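The vocabulary-size fallback above now uses cmp.Or, which returns the first of its arguments that is not the zero value. For example:

package main

import (
    "cmp"
    "fmt"
)

func main() {
    // top-level vocab_size wins when set...
    fmt.Println(cmp.Or(uint32(32000), uint32(128256))) // 32000

    // ...otherwise the text_config value is used
    fmt.Println(cmp.Or(uint32(0), uint32(128256))) // 128256

    // both zero: stays zero, triggering the "not explicitly set" branch
    fmt.Println(cmp.Or(uint32(0), uint32(0))) // 0
}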
@ -246,5 +241,13 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
|
||||
return writeFile(f, conv.KV(t), conv.Tensors(ts))
|
||||
}
|
||||
|
||||
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
|
||||
for i := range ts {
|
||||
ts[i].Shape = slices.Clone(ts[i].Shape)
|
||||
slices.Reverse(ts[i].Shape)
|
||||
}
|
||||
return ggml.WriteGGUF(f, kv, ts)
|
||||
}
|
||||
|
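The new writeFile helper reverses each tensor's shape before writing, presumably to match the dimension order GGUF expects, and clones the slice first so the converter's own copy is left untouched. A tiny illustration of that clone-then-reverse behavior (the shape values are made up):

package main

import (
    "fmt"
    "slices"
)

func main() {
    shape := []uint64{8, 4096, 14336}

    // clone before reversing so the original slice is not mutated
    reversed := slices.Clone(shape)
    slices.Reverse(reversed)

    fmt.Println(shape)    // [8 4096 14336]
    fmt.Println(reversed) // [14336 4096 8]
}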
@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
if slices.Contains([]string{
|
||||
"embeddings.position_ids",
|
||||
@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
continue
|
||||
}
|
||||
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
@ -43,10 +43,10 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *commandrModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
|
||||
t.SetRepacker(p.addOne)
|
||||
}
|
||||
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
@ -21,8 +21,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
shape := t.Shape()
|
||||
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
|
||||
@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
@ -28,12 +28,12 @@ type llamaModel struct {
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScaling struct {
|
||||
Type string `json:"type"`
|
||||
RopeType string `json:"rope_type"`
|
||||
Factor float32 `json:"factor"`
|
||||
LowFrequencyFactor float32 `json:"low_freq_factor"`
|
||||
HighFrequencyFactor float32 `json:"high_freq_factor"`
|
||||
OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
|
||||
Type string `json:"type"`
|
||||
RopeType string `json:"rope_type"`
|
||||
Factor float32 `json:"factor"`
|
||||
LowFrequencyFactor float32 `json:"low_freq_factor"`
|
||||
HighFrequencyFactor float32 `json:"high_freq_factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
|
||||
factors ropeFactor
|
||||
} `json:"rope_scaling"`
|
||||
@ -42,6 +42,8 @@ type llamaModel struct {
|
||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||
NormEpsilon float32 `json:"norm_epsilon"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
|
||||
skipRepack bool
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*llamaModel)(nil)
|
||||
@ -70,6 +72,10 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
|
||||
}
|
||||
|
||||
if p.HeadDim > 0 {
|
||||
kv["llama.attention.head_dim"] = p.HeadDim
|
||||
}
|
||||
|
||||
if p.RopeTheta > 0 {
|
||||
kv["llama.rope.freq_base"] = p.RopeTheta
|
||||
}
|
||||
@ -84,7 +90,7 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
|
||||
factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
|
||||
|
||||
original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
|
||||
original := cmp.Or(p.RopeScaling.OriginalMaxPositionEmbeddings, 8192)
|
||||
lambdaLow := float32(original) / factorLow
|
||||
lambdaHigh := float32(original) / factorHigh
|
||||
|
||||
@ -120,11 +126,11 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
if p.RopeScaling.factors != nil {
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "rope_freqs.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
|
||||
@ -133,12 +139,13 @@ func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
}
|
||||
|
||||
for _, t := range ts {
|
||||
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
|
||||
strings.HasSuffix(t.Name(), "attn_k.weight") {
|
||||
t.SetRepacker(p.repack)
|
||||
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
|
||||
if !p.skipRepack {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
169
convert/convert_llama4.go
Normal file
@ -0,0 +1,169 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type llama4Model struct {
|
||||
ModelParameters
|
||||
TextModel struct {
|
||||
llamaModel
|
||||
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
|
||||
NumLocalExperts uint32 `json:"num_local_experts"`
|
||||
InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
|
||||
UseQKNorm bool `json:"use_qk_norm"`
|
||||
IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"`
|
||||
AttentionChunkSize uint32 `json:"attention_chunk_size"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct {
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
NormEpsilon float32 `json:"norm_eps"`
|
||||
PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
// KV implements ModelConverter.
|
||||
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "llama4"
|
||||
|
||||
for k, v := range p.TextModel.KV(t) {
|
||||
if strings.HasPrefix(k, "llama.") {
|
||||
kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
|
||||
}
|
||||
}
|
||||
|
||||
kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
|
||||
kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize
|
||||
|
||||
kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
|
||||
kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
|
||||
kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
|
||||
kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
|
||||
kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
|
||||
|
||||
kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||
kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
|
||||
kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
|
||||
kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
|
||||
kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
|
||||
kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
|
||||
kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
|
||||
kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
|
||||
kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
|
||||
return kv
|
||||
}
|
||||
|
||||
// Replacements implements ModelConverter.
|
||||
func (p *llama4Model) Replacements() []string {
|
||||
return append(
|
||||
p.TextModel.Replacements(),
|
||||
"language_model.", "",
|
||||
"vision_model", "v",
|
||||
"multi_modal_projector", "mm",
|
||||
"feed_forward.down_proj", "ffn_down",
|
||||
"feed_forward.up_proj", "ffn_up",
|
||||
"feed_forward.gate_proj", "ffn_gate",
|
||||
"feed_forward.", "ffn_",
|
||||
"shared_expert.down_proj", "down_shexp",
|
||||
"shared_expert.gate_proj", "gate_shexp",
|
||||
"shared_expert.up_proj", "up_shexp",
|
||||
"experts.down_proj", "down_exps.weight",
|
||||
"experts.gate_up_proj", "gate_up_exps.weight",
|
||||
"router", "gate_inp",
|
||||
"patch_embedding.linear", "patch_embedding",
|
||||
)
|
||||
}
|
||||
|
||||
// Tensors implements ModelConverter.
|
||||
func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
var textTensors []Tensor
|
||||
for _, t := range ts {
|
||||
if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
|
||||
// gate and up projectors are fused
|
||||
// dims[1], dims[2] must be swapped
|
||||
// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
|
||||
halfDim := int(t.Shape()[2]) / 2
|
||||
|
||||
newShape := slices.Clone(t.Shape())
|
||||
newShape[1], newShape[2] = newShape[2]/2, newShape[1]
|
||||
for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
|
||||
// clone tensor since we need separate repackers
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
|
||||
Kind: tt.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: tt,
|
||||
})
|
||||
}
|
||||
} else if strings.Contains(t.Name(), "ffn_down_exps") {
|
||||
// dims[1], dims[2] must be swapped
|
||||
// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
|
||||
t.SetRepacker(p.repack())
|
||||
newShape := slices.Clone(t.Shape())
|
||||
newShape[1], newShape[2] = newShape[2], newShape[1]
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: newShape,
|
||||
WriterTo: t,
|
||||
})
|
||||
} else {
|
||||
textTensors = append(textTensors, t)
|
||||
}
|
||||
}
|
||||
|
||||
p.TextModel.skipRepack = true
|
||||
out = append(out, p.TextModel.Tensors(textTensors)...)
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
|
||||
return func(name string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i, dim := range shape {
|
||||
dims[i] = int(dim)
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
t, err := t.Slice(slice...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.T(0, 2, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
// flatten the tensor so it can be returned as a vector
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
}
|
||||
}
|
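The llama4 converter above splits the fused experts.gate_up_proj tensor into separate gate and up tensors: the last dimension is halved and dims 1 and 2 are swapped during repacking. A toy illustration of just the shape bookkeeping with plain slices (the sizes are made up; the actual slicing and transpose happen in the repack function shown above):

package main

import (
    "fmt"
    "slices"
)

func main() {
    // made-up fused shape: [experts, hidden_size, intermediate_size * 2]
    shape := []uint64{16, 5120, 16384}

    halfDim := int(shape[2]) / 2 // 8192: the gate half and the up half

    // target shape for each split tensor: [experts, intermediate_size, hidden_size]
    newShape := slices.Clone(shape)
    newShape[1], newShape[2] = newShape[2]/2, newShape[1]

    for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
        fmt.Printf("%s: columns [%d:%d) of the fused tensor, shape %v\n",
            name, i*halfDim, (i+1)*halfDim, newShape)
    }
}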
@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
shape := t.Shape()
|
||||
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
|
||||
@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: shape,
|
||||
|
190
convert/convert_mistral.go
Normal file
@ -0,0 +1,190 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type mistral3Model struct {
|
||||
ModelParameters
|
||||
ImageTokenIndex uint32 `json:"image_token_index"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
VisionFeatureLayer int32 `json:"vision_feature_layer"`
|
||||
TextModel struct {
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
SlidingWindow *uint32 `json:"sliding_window"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct {
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
NumChannels uint32 `json:"num_channels"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
} `json:"vision_config"`
|
||||
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
|
||||
ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
}
|
||||
|
||||
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mistral3"
|
||||
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
|
||||
|
||||
// Text configuration
|
||||
kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
|
||||
kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
|
||||
kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
|
||||
kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
|
||||
kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
|
||||
kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
|
||||
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
|
||||
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
|
||||
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
|
||||
|
||||
// Vision configuration
|
||||
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||
kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
|
||||
kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
|
||||
kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
|
||||
kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
|
||||
kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
|
||||
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
|
||||
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
|
||||
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
|
||||
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
|
||||
|
||||
// Multimodal configuration
|
||||
kv["mistral3.image_token_index"] = p.ImageTokenIndex
|
||||
kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
|
||||
|
||||
kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
|
||||
|
||||
if p.ProjectorHiddenAct != "" {
|
||||
kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
if !strings.HasPrefix(t.Name(), "v.") {
|
||||
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
|
||||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
|
||||
t.SetRepacker(p.repack)
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *mistral3Model) Replacements() []string {
|
||||
return []string{
|
||||
"language_model.model.norm", "output_norm",
|
||||
"language_model.model.", "",
|
||||
"language_model.", "",
|
||||
"layers", "blk",
|
||||
"transformer.layers", "blk",
|
||||
"vision_tower", "v",
|
||||
"ln_pre", "encoder_norm",
|
||||
"input_layernorm", "attn_norm",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"embed_tokens", "token_embd",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"attention.q_proj", "attn_q",
|
||||
"attention.k_proj", "attn_k",
|
||||
"attention.v_proj", "attn_v",
|
||||
"attention.o_proj", "attn_output",
|
||||
"attention_norm", "attn_norm",
|
||||
"feed_forward.gate_proj", "ffn_gate",
|
||||
"feed_forward.down_proj", "ffn_down",
|
||||
"feed_forward.up_proj", "ffn_up",
|
||||
"multi_modal_projector", "mm",
|
||||
"ffn_norm", "ffn_norm",
|
||||
"lm_head", "output",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
|
||||
var dims []int
|
||||
for _, dim := range shape {
|
||||
dims = append(dims, int(dim))
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
if strings.HasSuffix(name, ".attn_q.weight") {
|
||||
heads = p.TextModel.NumAttentionHeads
|
||||
} else if strings.HasSuffix(name, ".attn_k.weight") {
|
||||
heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
|
||||
}
|
||||
|
||||
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Reshape(dims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(n, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
for _, t := range ts {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
return f32s, nil
|
||||
}
|
@ -29,7 +29,7 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
oldnew := []string{
|
||||
"model.layers", "blk",
|
||||
"w1", "ffn_gate_exps",
|
||||
@ -56,10 +56,10 @@ func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
return true
|
||||
})
|
||||
|
||||
var out []ggml.Tensor
|
||||
var out []*ggml.Tensor
|
||||
for n, e := range experts {
|
||||
// TODO(mxyng): sanity check experts
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: n,
|
||||
Kind: e[0].Kind(),
|
||||
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
|
||||
|
160
convert/convert_mllama.go
Normal file
@ -0,0 +1,160 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
type mllamaModel struct {
|
||||
ModelParameters
|
||||
TextModel struct {
|
||||
llamaModel
|
||||
|
||||
CrossAttentionLayers []int32 `json:"cross_attention_layers"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct {
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumGlobalLayers uint32 `json:"num_global_layers"`
|
||||
IntermediateLayersIndices []int32 `json:"intermediate_layers_indices"`
|
||||
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
|
||||
AttentionHeads uint32 `json:"attention_heads"`
|
||||
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
NumChannels uint32 `json:"num_channels"`
|
||||
MaxNumTiles uint32 `json:"max_num_tiles"`
|
||||
NormEpsilon float32 `json:"norm_eps"`
|
||||
RopeTheta float32 `json:"rope.freq_base"`
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "mllama"
|
||||
|
||||
for k, v := range m.TextModel.KV(t) {
|
||||
if strings.HasPrefix(k, "llama.") {
|
||||
kv[strings.ReplaceAll(k, "llama.", "mllama.")] = v
|
||||
}
|
||||
}
|
||||
|
||||
kv["mllama.attention.cross_attention_layers"] = m.TextModel.CrossAttentionLayers
|
||||
|
||||
kv["mllama.vision.block_count"] = m.VisionModel.NumHiddenLayers
|
||||
kv["mllama.vision.global.block_count"] = m.VisionModel.NumGlobalLayers
|
||||
kv["mllama.vision.intermediate_layers_indices"] = m.VisionModel.IntermediateLayersIndices
|
||||
|
||||
kv["mllama.vision.embedding_length"] = m.VisionModel.HiddenSize
|
||||
kv["mllama.vision.feed_forward_length"] = m.VisionModel.IntermediateSize
|
||||
|
||||
kv["mllama.vision.attention.head_count"] = m.VisionModel.AttentionHeads
|
||||
kv["mllama.vision.attention.layer_norm_epsilon"] = m.VisionModel.NormEpsilon
|
||||
|
||||
kv["mllama.vision.image_size"] = m.VisionModel.ImageSize
|
||||
kv["mllama.vision.patch_size"] = m.VisionModel.PatchSize
|
||||
kv["mllama.vision.max_num_tiles"] = m.VisionModel.MaxNumTiles
|
||||
kv["mllama.vision.num_channels"] = m.VisionModel.NumChannels
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *mllamaModel) Replacements() []string {
|
||||
return append(
|
||||
m.TextModel.Replacements(),
|
||||
"language_model.", "",
|
||||
"gate_attn", "attn_gate",
|
||||
"gate_ffn", "ffn_gate",
|
||||
"cross_attn.", "cross_attn_",
|
||||
"vision_model", "v",
|
||||
"class_embedding", "class_embd",
|
||||
"patch_embedding", "patch_embd",
|
||||
"gated_positional_embedding.tile_embedding", "tile_position_embd",
|
||||
"gated_positional_embedding.embedding", "position_embd.weight",
|
||||
"gated_positional_embedding", "position_embd",
|
||||
"embedding.weight", "weight",
|
||||
"pre_tile_positional_embedding", "pre_tile_position_embd",
|
||||
"post_tile_positional_embedding", "post_tile_position_embd",
|
||||
"layernorm_pre", "pre_ln",
|
||||
"layernorm_post", "post_ln",
|
||||
"global_transformer.layers", "global.blk",
|
||||
"transformer.layers", "blk",
|
||||
"mlp.fc1", "ffn_up",
|
||||
"mlp.fc2", "ffn_down",
|
||||
"multi_modal_projector", "mm.0",
|
||||
)
|
||||
}
|
||||
|
||||
func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
var text []Tensor
|
||||
for _, t := range ts {
|
||||
if t.Name() == "v.position_embd.gate" {
|
||||
for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(m.repack(name))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: tt,
|
||||
})
|
||||
}
|
||||
} else if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
|
||||
t.SetRepacker(m.repack(t.Name()))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
} else if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
} else {
|
||||
text = append(text, t)
|
||||
}
|
||||
}
|
||||
|
||||
return append(out, m.TextModel.Tensors(text)...)
|
||||
}
|
||||
|
||||
func (m *mllamaModel) repack(name string) Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) (_ []float32, err error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i, dim := range shape {
|
||||
dims[i] = int(dim)
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
|
||||
t, err = tensor.Tanh(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if name == "v.position_embd.gate" {
|
||||
t, err = tensor.Sub(float32(1), t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
// flatten the tensor so it can be returned as a vector
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
}
|
||||
}
|
@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var addRopeFactors sync.Once
|
||||
|
||||
out := make([]ggml.Tensor, 0, len(ts)+2)
|
||||
out := make([]*ggml.Tensor, 0, len(ts)+2)
|
||||
for _, t := range ts {
|
||||
if strings.HasPrefix(t.Name(), "blk.0.") {
|
||||
addRopeFactors.Do(func() {
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "rope_factors_long.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
|
||||
WriterTo: p.RopeScaling.LongFactor,
|
||||
}, ggml.Tensor{
|
||||
}, &ggml.Tensor{
|
||||
Name: "rope_factors_short.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
|
||||
@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
})
|
||||
}
|
||||
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
@ -118,6 +118,5 @@ func (p *phi3Model) Replacements() []string {
|
||||
type ropeFactor []float32
|
||||
|
||||
func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
|
||||
err := binary.Write(w, binary.LittleEndian, r)
|
||||
return 0, err
|
||||
return 0, binary.Write(w, binary.LittleEndian, r)
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ type qwen2Model struct {
|
||||
Type string `json:"type"`
|
||||
Factor ropeFactor `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
MropeSection []int32 `json:"mrope_section"`
|
||||
} `json:"rope_scaling"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
}
|
||||
@ -39,16 +40,18 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
|
||||
case "yarn":
|
||||
kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type
|
||||
kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor
|
||||
case "mrope", "default":
|
||||
kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection
|
||||
default:
|
||||
panic("unknown rope scaling type")
|
||||
}
|
||||
return kv
|
||||
}
|
||||
|
||||
func (q *qwen2Model) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
out = append(out, ggml.Tensor{
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
|
102
convert/convert_qwen25vl.go
Normal file
@ -0,0 +1,102 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type qwen25VLModel struct {
|
||||
qwen2Model
|
||||
|
||||
VisionModel struct {
|
||||
Depth uint32 `json:"depth"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHeads uint32 `json:"num_heads"`
|
||||
InChannels uint32 `json:"in_chans"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
SpatialPatchSize uint32 `json:"spatial_patch_size"`
|
||||
WindowSize uint32 `json:"window_size"`
|
||||
RMSNormEps float32 `json:"layer_norm_epsilon"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
FullAttentionBlocks []int32 `json:"fullatt_block_indexes"`
|
||||
TemporalPatchSize uint32 `json:"temporal_patch_size"`
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*qwen25VLModel)(nil)
|
||||
|
||||
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
|
||||
for k, v := range q.qwen2Model.KV(t) {
|
||||
if strings.HasPrefix(k, "qwen2.") {
|
||||
kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v
|
||||
}
|
||||
}
|
||||
|
||||
if q.VisionModel.FullAttentionBlocks == nil {
|
||||
kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31}
|
||||
}
|
||||
|
||||
kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32)
|
||||
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
|
||||
kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16)
|
||||
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
|
||||
kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14)
|
||||
kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2)
|
||||
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
|
||||
kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112)
|
||||
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
|
||||
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4)
|
||||
kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
|
||||
kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2)
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
if strings.Contains(t.Name(), "patch_embed.proj") {
|
||||
for t := range splitDim(t, 2,
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
|
||||
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
|
||||
) {
|
||||
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
|
||||
out = append(out, t)
|
||||
}
|
||||
} else if strings.Contains(t.Name(), "attn.qkv") {
|
||||
out = append(out, slices.Collect(splitDim(t, 0,
|
||||
strings.NewReplacer("attn.qkv", "attn_q"),
|
||||
strings.NewReplacer("attn.qkv", "attn_k"),
|
||||
strings.NewReplacer("attn.qkv", "attn_v"),
|
||||
))...)
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *qwen25VLModel) Replacements() []string {
|
||||
return append(
|
||||
p.qwen2Model.Replacements(),
|
||||
"visual", "v",
|
||||
"blocks", "blk",
|
||||
"attn.proj", "attn_out",
|
||||
"norm1", "ln1",
|
||||
"norm2", "ln2",
|
||||
)
|
||||
}
|
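For the qwen2.5-VL vision blocks above, splitDim carves the fused attn.qkv weight into attn_q, attn_k, and attn_v along dim 0. splitDim's implementation is not part of this excerpt; the snippet below is only a simplified illustration of splitting a row-major weight into thirds along its first dimension, with made-up sizes:

package main

import "fmt"

// splitQKV slices a fused [3*dim, cols] weight, stored row-major as a flat
// buffer, into its query, key, and value thirds along dim 0.
func splitQKV(data []float32, dim, cols int) (q, k, v []float32) {
    stride := dim * cols
    return data[0:stride], data[stride : 2*stride], data[2*stride : 3*stride]
}

func main() {
    dim, cols := 2, 3 // illustrative sizes
    data := make([]float32, 3*dim*cols)
    for i := range data {
        data[i] = float32(i)
    }

    q, k, v := splitQKV(data, dim, cols)
    fmt.Println(q) // [0 1 2 3 4 5]
    fmt.Println(k) // [6 7 8 9 10 11]
    fmt.Println(v) // [12 13 14 15 16 17]
}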
@ -11,7 +11,6 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
@ -48,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
|
||||
}
|
||||
t.Cleanup(func() { r.Close() })
|
||||
|
||||
m, _, err := ggml.Decode(r, math.MaxInt)
|
||||
m, _, err := ggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -131,6 +130,7 @@ func TestConvertModel(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer expectFile.Close()
|
||||
|
||||
var expect map[string]string
|
||||
if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
|
||||
@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
m, _, err := ggml.Decode(r, math.MaxInt)
|
||||
m, _, err := ggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1,58 +0,0 @@
package convert

import (
	"archive/zip"
	"errors"
	"io"
	"io/fs"
	"os"
	"path/filepath"
)

type ZipReader struct {
	r *zip.Reader
	p string

	// limit is the maximum size of a file that can be read directly
	// from the zip archive. Files larger than this size will be extracted
	limit int64
}

func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
	return &ZipReader{r, p, limit}
}

func (z *ZipReader) Open(name string) (fs.File, error) {
	r, err := z.r.Open(name)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	if fi, err := r.Stat(); err != nil {
		return nil, err
	} else if fi.Size() < z.limit {
		return r, nil
	}

	if !filepath.IsLocal(name) {
		return nil, zip.ErrInsecurePath
	}

	n := filepath.Join(z.p, name)
	if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
		w, err := os.Create(n)
		if err != nil {
			return nil, err
		}
		defer w.Close()

		if _, err := io.Copy(w, r); err != nil {
			return nil, err
		}
	} else if err != nil {
		return nil, err
	}

	return os.Open(n)
}
@ -11,14 +11,15 @@ type Tensor interface {
|
||||
Name() string
|
||||
Shape() []uint64
|
||||
Kind() uint32
|
||||
SetRepacker(repacker)
|
||||
SetRepacker(Repacker)
|
||||
WriteTo(io.Writer) (int64, error)
|
||||
Clone() Tensor
|
||||
}
|
||||
|
||||
type tensorBase struct {
|
||||
name string
|
||||
shape []uint64
|
||||
repacker
|
||||
name string
|
||||
shape []uint64
|
||||
repacker Repacker
|
||||
}
|
||||
|
||||
func (t tensorBase) Name() string {
|
||||
@ -36,7 +37,11 @@ const (
|
||||
|
||||
func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
|
||||
t.name == "token_types.weight" {
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
t.name == "v.pre_tile_position_embd.weight" ||
|
||||
t.name == "v.post_tile_position_embd.weight" {
|
||||
// these tensors are always F32
|
||||
return 0
|
||||
}
|
||||
@ -51,21 +56,18 @@ func (t tensorBase) Kind() uint32 {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tensorBase) SetRepacker(fn repacker) {
|
||||
func (t *tensorBase) SetRepacker(fn Repacker) {
|
||||
t.repacker = fn
|
||||
}
|
||||
|
||||
type repacker func(string, []float32, []uint64) ([]float32, error)
|
||||
type Repacker func(string, []float32, []uint64) ([]float32, error)
|
||||
|
||||
func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
|
||||
patterns := []struct {
|
||||
Pattern string
|
||||
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
|
||||
}{
|
||||
{"model-*-of-*.safetensors", parseSafetensors},
|
||||
{"model.safetensors", parseSafetensors},
|
||||
{"adapters.safetensors", parseSafetensors},
|
||||
{"adapter_model.safetensors", parseSafetensors},
|
||||
{"*.safetensors", parseSafetensors},
|
||||
{"pytorch_model-*-of-*.bin", parseTorch},
|
||||
{"pytorch_model.bin", parseTorch},
|
||||
{"consolidated.*.pth", parseTorch},
|
||||
|
@ -94,6 +94,21 @@ type safetensor struct {
|
||||
*tensorBase
|
||||
}
|
||||
|
||||
func (st safetensor) Clone() Tensor {
|
||||
return &safetensor{
|
||||
fs: st.fs,
|
||||
path: st.path,
|
||||
dtype: st.dtype,
|
||||
offset: st.offset,
|
||||
size: st.size,
|
||||
tensorBase: &tensorBase{
|
||||
name: st.name,
|
||||
repacker: st.repacker,
|
||||
shape: slices.Clone(st.shape),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
f, err := st.fs.Open(st.path)
|
||||
if err != nil {
|
||||
|
@ -43,6 +43,17 @@ type torch struct {
|
||||
*tensorBase
|
||||
}
|
||||
|
||||
func (t torch) Clone() Tensor {
|
||||
return torch{
|
||||
storage: t.storage,
|
||||
tensorBase: &tensorBase{
|
||||
name: t.name,
|
||||
shape: t.shape,
|
||||
repacker: t.repacker,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (pt torch) WriteTo(w io.Writer) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {
|
||||
|
||||
var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
|
||||
var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
|
||||
var file_sentencepiece_model_proto_goTypes = []interface{}{
|
||||
var file_sentencepiece_model_proto_goTypes = []any{
|
||||
(TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType
|
||||
(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
|
||||
(*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec
|
||||
@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*TrainerSpec); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*NormalizerSpec); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*SelfTestData); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*ModelProto); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*SelfTestData_Sample); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
|
||||
file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*ModelProto_SentencePiece); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
|
56 convert/tensor.go Normal file
@ -0,0 +1,56 @@
package convert

import (
	"iter"
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"
)

// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
// is split evenly based on the number of replacers provided.
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
	return func(yield func(*ggml.Tensor) bool) {
		for i, replacer := range replacers {
			shape := slices.Clone(t.Shape())
			shape[dim] = shape[dim] / uint64(len(replacers))

			slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
			slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))

			tt := t.Clone()
			tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
				dims := make([]int, len(shape))
				for i := range shape {
					dims[i] = int(shape[i])
				}

				var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
				t, err := t.Slice(slice...)
				if err != nil {
					return nil, err
				}

				t = tensor.Materialize(t)
				// flatten tensor so it can be written as a vector
				if err := t.Reshape(t.Shape().TotalSize()); err != nil {
					return nil, err
				}

				return native.VectorF32(t.(*tensor.Dense))
			})

			if !yield(&ggml.Tensor{
				Name:     replacer.Replace(t.Name()),
				Kind:     t.Kind(),
				Shape:    shape,
				WriterTo: tt,
			}) {
				break
			}
		}
	}
}
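As a side note for reviewers, the slicing idea splitDim relies on can be seen in isolation in the minimal sketch below, which uses the same github.com/pdevine/tensor library; the toy shape and values are assumptions for illustration, not part of the converter:

```go
package main

import (
	"fmt"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"
)

func main() {
	// A toy fused "qkv" tensor with shape [6, 2]: rows 0-1 are Q, 2-3 are K, 4-5 are V.
	data := []float32{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6}
	t := tensor.New(tensor.WithShape(6, 2), tensor.WithBacking(data))

	// Split dimension 0 into three equal parts, mirroring what splitDim does
	// when mapping attn.qkv to attn_q, attn_k, and attn_v.
	const parts = 3
	rows := 6 / parts
	for i := 0; i < parts; i++ {
		s, err := t.Slice(tensor.S(i*rows, (i+1)*rows), nil) // nil keeps the whole second dimension
		if err != nil {
			panic(err)
		}

		dense := tensor.Materialize(s).(*tensor.Dense)
		// Flatten so the slice can be written out as a vector, as the repacker does.
		if err := dense.Reshape(dense.Shape().TotalSize()); err != nil {
			panic(err)
		}

		v, err := native.VectorF32(dense)
		if err != nil {
			panic(err)
		}
		fmt.Println(v)
	}
}
```

Each slice is materialized and flattened before being handed back, which is exactly what the repacker in splitDim does when the tensor data is finally written out.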
@ -12,7 +12,7 @@ func IsNUMA() bool {
|
||||
// numa support in llama.cpp is linux only
|
||||
return false
|
||||
}
|
||||
ids := map[string]interface{}{}
|
||||
ids := map[string]any{}
|
||||
packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
|
||||
for _, packageId := range packageIds {
|
||||
id, err := os.ReadFile(packageId)
|
||||
|
@ -670,7 +670,7 @@ func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, e
|
||||
}
|
||||
|
||||
func getVerboseState() C.uint16_t {
|
||||
if envconfig.Debug() {
|
||||
if envconfig.LogLevel() < slog.LevelInfo {
|
||||
return C.uint16_t(1)
|
||||
}
|
||||
return C.uint16_t(0)
|
||||
|
@ -27,12 +27,14 @@
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef LOG
|
||||
#define LOG(verbose, ...) \
|
||||
do { \
|
||||
if (verbose) { \
|
||||
fprintf(stderr, __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
@ -1,6 +1,7 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
#include <inttypes.h>
|
||||
#include "gpu_info_cudart.h"
|
||||
|
||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
@ -58,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
|
||||
if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) {
|
||||
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
|
||||
return;
|
||||
}
|
||||
@ -168,9 +169,9 @@ void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->free = memInfo.free;
|
||||
resp->used = memInfo.used;
|
||||
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
|
||||
LOG(h.verbose, "[%s] CUDA usedMem %lu\n", resp->gpu_id, resp->used);
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free);
|
||||
LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used);
|
||||
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||
}
|
||||
|
||||
@ -180,4 +181,4 @@ void cudart_release(cudart_handle_t h) {
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
#endif // __APPLE__
|
||||
|
@ -1,6 +1,7 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
#include <inttypes.h>
|
||||
#include "gpu_info_nvcuda.h"
|
||||
|
||||
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
|
||||
@ -193,8 +194,8 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->total = memInfo.total;
|
||||
resp->free = memInfo.free;
|
||||
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024);
|
||||
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||
|
||||
|
||||
@ -247,4 +248,4 @@ void nvcuda_release(nvcuda_handle_t h) {
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
#endif // __APPLE__
|
||||
|
@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
return linuxCPUDetails(file)
|
||||
}
|
||||
|
||||
@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
|
||||
for id, s := range socketByID {
|
||||
s.CoreCount = len(coreBySocket[id])
|
||||
s.ThreadCount = 0
|
||||
for _, tc := range threadsByCoreBySocket[id] {
|
||||
s.ThreadCount += tc
|
||||
}
|
||||
|
||||
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
|
||||
efficiencyCoreCount := 0
|
||||
for _, threads := range threadsByCoreBySocket[id] {
|
||||
s.ThreadCount += threads
|
||||
if threads == 1 {
|
||||
efficiencyCoreCount++
|
||||
}
|
||||
|
89 docs/api.md
@ -19,7 +19,7 @@
|
||||
|
||||
### Model names
|
||||
|
||||
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
|
||||
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q8_0` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
|
||||
|
||||
### Durations
|
||||
|
||||
@ -173,7 +173,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
|
||||
##### Response
|
||||
|
||||
```json
|
||||
```json5
|
||||
{
|
||||
"model": "codellama:code",
|
||||
"created_at": "2024-07-22T20:47:51.147561Z",
|
||||
@ -394,9 +394,6 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"repeat_penalty": 1.2,
|
||||
"presence_penalty": 1.5,
|
||||
"frequency_penalty": 1.0,
|
||||
"mirostat": 1,
|
||||
"mirostat_tau": 0.8,
|
||||
"mirostat_eta": 0.6,
|
||||
"penalize_newline": true,
|
||||
"stop": ["\n", "user:"],
|
||||
"numa": false,
|
||||
@ -404,10 +401,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"num_batch": 2,
|
||||
"num_gpu": 1,
|
||||
"main_gpu": 0,
|
||||
"low_vram": false,
|
||||
"vocab_only": false,
|
||||
"use_mmap": true,
|
||||
"use_mlock": false,
|
||||
"num_thread": 8
|
||||
}
|
||||
}'
|
||||
@ -558,6 +552,10 @@ Final response:
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ""
|
||||
},
|
||||
"done": true,
|
||||
"total_duration": 4883583458,
|
||||
"load_duration": 1334875,
|
||||
@ -954,19 +952,8 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
|
||||
|
||||
| Type | Recommended |
|
||||
| --- | :-: |
|
||||
| q2_K | |
|
||||
| q3_K_L | |
|
||||
| q3_K_M | |
|
||||
| q3_K_S | |
|
||||
| q4_0 | |
|
||||
| q4_1 | |
|
||||
| q4_K_M | * |
|
||||
| q4_K_S | |
|
||||
| q5_0 | |
|
||||
| q5_1 | |
|
||||
| q5_K_M | |
|
||||
| q5_K_S | |
|
||||
| q6_K | |
|
||||
| q8_0 | * |
|
||||
|
||||
### Examples
|
||||
@ -1011,8 +998,8 @@ Quantize a non-quantized model.
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/create -d '{
|
||||
"model": "llama3.1:quantized",
|
||||
"from": "llama3.1:8b-instruct-fp16",
|
||||
"model": "llama3.2:quantized",
|
||||
"from": "llama3.2:3b-instruct-fp16",
|
||||
"quantize": "q4_K_M"
|
||||
}'
|
||||
```
|
||||
@ -1022,12 +1009,14 @@ curl http://localhost:11434/api/create -d '{
|
||||
A stream of JSON objects is returned:
|
||||
|
||||
```json
|
||||
{"status":"quantizing F16 model to Q4_K_M"}
|
||||
{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
|
||||
{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
|
||||
{"status":"using existing layer sha256:0ba8f0e314b4264dfd19df045cde9d4c394a52474bf92ed6a3de22a4ca31a177"}
|
||||
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":12302}
|
||||
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":6433687552}
|
||||
{"status":"verifying conversion"}
|
||||
{"status":"creating new layer sha256:fb7f4f211b89c6c4928ff4ddb73db9f9c0cfca3e000c3e40d6cf27ddc6ca72eb"}
|
||||
{"status":"using existing layer sha256:966de95ca8a62200913e3f8bfbf84c8494536f1b94b49166851e76644e966396"}
|
||||
{"status":"using existing layer sha256:fcc5a6bec9daf9b561a68827b67ab6088e1dba9d1fa2a50d7bbcc8384e0a265d"}
|
||||
{"status":"using existing layer sha256:a70ff7e570d97baaf4e62ac6e6ad9975e04caa6d900d3742d37698494479e0cd"}
|
||||
{"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}
|
||||
{"status":"creating new layer sha256:455f34728c9b5dd3376378bfb809ee166c145b0b4c1f1a6feca069055066ef9a"}
|
||||
{"status":"writing manifest"}
|
||||
{"status":"success"}
|
||||
```
|
||||
@ -1165,29 +1154,37 @@ A single JSON object will be returned.
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"name": "codellama:13b",
|
||||
"modified_at": "2023-11-04T14:56:49.277302595-07:00",
|
||||
"size": 7365960935,
|
||||
"digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
|
||||
"name": "deepseek-r1:latest",
|
||||
"model": "deepseek-r1:latest",
|
||||
"modified_at": "2025-05-10T08:06:48.639712648-07:00",
|
||||
"size": 4683075271,
|
||||
"digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163",
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
"family": "llama",
|
||||
"families": null,
|
||||
"parameter_size": "13B",
|
||||
"quantization_level": "Q4_0"
|
||||
"family": "qwen2",
|
||||
"families": [
|
||||
"qwen2"
|
||||
],
|
||||
"parameter_size": "7.6B",
|
||||
"quantization_level": "Q4_K_M"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "llama3:latest",
|
||||
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
|
||||
"size": 3825819519,
|
||||
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
|
||||
"name": "llama3.2:latest",
|
||||
"model": "llama3.2:latest",
|
||||
"modified_at": "2025-05-04T17:37:44.706015396-07:00",
|
||||
"size": 2019393189,
|
||||
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
"family": "llama",
|
||||
"families": null,
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
"families": [
|
||||
"llama"
|
||||
],
|
||||
"parameter_size": "3.2B",
|
||||
"quantization_level": "Q4_K_M"
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -1213,13 +1210,13 @@ Show information about a model including details, modelfile, template, parameter
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/show -d '{
|
||||
"model": "llama3.2"
|
||||
"model": "llava"
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```json
|
||||
```json5
|
||||
{
|
||||
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
|
||||
"parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"",
|
||||
@ -1256,7 +1253,11 @@ curl http://localhost:11434/api/show -d '{
|
||||
"tokenizer.ggml.pre": "llama-bpe",
|
||||
"tokenizer.ggml.token_type": [], // populates if `verbose=true`
|
||||
"tokenizer.ggml.tokens": [] // populates if `verbose=true`
|
||||
}
|
||||
},
|
||||
"capabilities": [
|
||||
"completion",
|
||||
"vision"
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
|
59 docs/benchmark.md Normal file
@ -0,0 +1,59 @@
# Benchmark

Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.

## When to use

Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups

## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`

## Usage and Examples

>[!NOTE]
>All commands must be run from the root directory of the Ollama project.

Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```

Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark

Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)

Common usage patterns:

Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```

## Output metrics

The benchmark reports several key metrics:

- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed

Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory

Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)
@ -20,7 +20,13 @@ Please refer to the [GPU docs](./gpu.md).

## How can I specify the context window size?

By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
By default, Ollama uses a context window size of 4096 tokens.

This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

```shell
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
```

To change this when using `ollama run`, use `/set parameter`:
@ -150,9 +150,6 @@ PARAMETER <parameter> <parametervalue>
|
||||
|
||||
| Parameter | Description | Value Type | Example Usage |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
|
||||
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
|
||||
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
|
@ -12,7 +12,7 @@ A basic Go template consists of three main parts:

Here's an example of a simple chat template:

```gotmpl
```go
{{- range .Messages }}
{{ .Role }}: {{ .Content }}
{{- end }}
@ -162,6 +162,6 @@ CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://o

Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.

```gotmpl
```go
[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
```
@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
|
||||
On **Linux** systems with systemd, the logs can be found with this command:
|
||||
|
||||
```shell
|
||||
journalctl -u ollama --no-pager
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
```
|
||||
|
||||
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
||||
@ -26,7 +26,6 @@ When you run Ollama on **Windows**, there are a few different locations. You can
|
||||
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
|
||||
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
|
||||
|
||||
To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
|
||||
|
||||
@ -69,10 +68,6 @@ If you run into problems on Linux and want to install an older version, or you'd
|
||||
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
|
||||
```
|
||||
|
||||
## Linux tmp noexec
|
||||
|
||||
If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
|
||||
|
||||
## Linux docker
|
||||
|
||||
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
|
||||
|
@ -62,7 +62,6 @@ the explorer window by hitting `<Ctrl>+R` and type in:
|
||||
- *upgrade.log* contains log output for upgrades
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
||||
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
|
||||
|
||||
## Uninstall
|
||||
|
||||
|
@ -149,9 +149,22 @@ func Bool(k string) func() bool {
	}
}

// LogLevel returns the log level for the application.
// Values are 0 or false INFO (Default), 1 or true DEBUG, 2 TRACE
func LogLevel() slog.Level {
	level := slog.LevelInfo
	if s := Var("OLLAMA_DEBUG"); s != "" {
		if b, _ := strconv.ParseBool(s); b {
			level = slog.LevelDebug
		} else if i, _ := strconv.ParseInt(s, 10, 64); i != 0 {
			level = slog.Level(i * -4)
		}
	}

	return level
}

var (
	// Debug enabled additional debug information.
	Debug = Bool("OLLAMA_DEBUG")
	// FlashAttention enables the experimental flash attention feature.
	FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
	// KvCacheType is the quantization type for the K/V cache.
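A hedged sketch of how a caller could wire LogLevel into slog, assuming the package lives at github.com/ollama/ollama/envconfig; the handler setup here is illustrative, not the server's actual logging code:

```go
package main

import (
	"log/slog"
	"os"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	// OLLAMA_DEBUG="" or "0" keeps INFO, "1"/"true" enables DEBUG,
	// "2" maps further down to TRACE, and negative values reduce verbosity.
	os.Setenv("OLLAMA_DEBUG", "1")

	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level: envconfig.LogLevel(),
	})
	slog.SetDefault(slog.New(handler))

	slog.Debug("now visible because OLLAMA_DEBUG=1")
}
```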
@ -169,7 +182,7 @@ var (
|
||||
// Enable the new Ollama engine
|
||||
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||
// ContextLength sets the default context length
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 2048)
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
@ -209,8 +222,6 @@ var (
|
||||
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
|
||||
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
|
||||
MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
|
||||
// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
|
||||
MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
|
||||
)
|
||||
|
||||
func Uint64(key string, defaultValue uint64) func() uint64 {
|
||||
@ -238,7 +249,7 @@ type EnvVar struct {
|
||||
|
||||
func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
@ -255,7 +266,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 2048)"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
|
||||
// Informational
|
||||
|
@ -1,11 +1,13 @@
|
||||
package envconfig
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
func TestHost(t *testing.T) {
|
||||
@ -279,8 +281,8 @@ func TestVar(t *testing.T) {
|
||||
|
||||
func TestContextLength(t *testing.T) {
|
||||
cases := map[string]uint{
|
||||
"": 2048,
|
||||
"4096": 4096,
|
||||
"": 4096,
|
||||
"2048": 2048,
|
||||
}
|
||||
|
||||
for k, v := range cases {
|
||||
@ -292,3 +294,34 @@ func TestContextLength(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogLevel(t *testing.T) {
|
||||
cases := map[string]slog.Level{
|
||||
// Default to INFO
|
||||
"": slog.LevelInfo,
|
||||
"false": slog.LevelInfo,
|
||||
"f": slog.LevelInfo,
|
||||
"0": slog.LevelInfo,
|
||||
|
||||
// True values enable Debug
|
||||
"true": slog.LevelDebug,
|
||||
"t": slog.LevelDebug,
|
||||
|
||||
// Positive values increase verbosity
|
||||
"1": slog.LevelDebug,
|
||||
"2": logutil.LevelTrace,
|
||||
|
||||
// Negative values decrease verbosity
|
||||
"-1": slog.LevelWarn,
|
||||
"-2": slog.LevelError,
|
||||
}
|
||||
|
||||
for k, v := range cases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_DEBUG", k)
|
||||
if i := LogLevel(); i != v {
|
||||
t.Errorf("%s: expected %d, got %d", k, v, i)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func assertEqual(t *testing.T, a interface{}, b interface{}) {
|
||||
func assertEqual(t *testing.T, a any, b any) {
|
||||
if a != b {
|
||||
t.Errorf("Assert failed, expected %v, got %v", b, a)
|
||||
}
|
||||
|
13 fs/config.go Normal file
@ -0,0 +1,13 @@
package fs

type Config interface {
	Architecture() string
	String(string, ...string) string
	Uint(string, ...uint32) uint32
	Float(string, ...float32) float32
	Bool(string, ...bool) bool

	Strings(string, ...[]string) []string
	Ints(string, ...[]int32) []int32
	Floats(string, ...[]float32) []float32
}
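For illustration, a minimal map-backed implementation of this interface might look like the sketch below; the fakeConfig type and lookup helper are hypothetical and not part of the repository:

```go
package fs

// fakeConfig is a hypothetical, map-backed Config, e.g. for unit tests.
type fakeConfig struct {
	arch string
	kv   map[string]any
}

func (c fakeConfig) Architecture() string { return c.arch }

func (c fakeConfig) String(k string, d ...string) string       { return lookup(c.kv, k, d) }
func (c fakeConfig) Uint(k string, d ...uint32) uint32         { return lookup(c.kv, k, d) }
func (c fakeConfig) Float(k string, d ...float32) float32      { return lookup(c.kv, k, d) }
func (c fakeConfig) Bool(k string, d ...bool) bool             { return lookup(c.kv, k, d) }
func (c fakeConfig) Strings(k string, d ...[]string) []string  { return lookup(c.kv, k, d) }
func (c fakeConfig) Ints(k string, d ...[]int32) []int32       { return lookup(c.kv, k, d) }
func (c fakeConfig) Floats(k string, d ...[]float32) []float32 { return lookup(c.kv, k, d) }

// lookup returns the stored value for key when it has type T, otherwise the
// optional default (or T's zero value when no default was passed).
func lookup[T any](kv map[string]any, key string, defaults []T) T {
	if v, ok := kv[key].(T); ok {
		return v
	}
	if len(defaults) > 0 {
		return defaults[0]
	}
	var zero T
	return zero
}
```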
218 fs/ggml/ggml.go
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@ -33,15 +34,15 @@ func (kv KV) Kind() string {
|
||||
}
|
||||
|
||||
func (kv KV) ParameterCount() uint64 {
|
||||
return keyValue[uint64](kv, "general.parameter_count")
|
||||
return keyValue(kv, "general.parameter_count", uint64(0))
|
||||
}
|
||||
|
||||
func (kv KV) FileType() fileType {
|
||||
func (kv KV) FileType() FileType {
|
||||
if t := kv.Uint("general.file_type"); t > 0 {
|
||||
return fileType(t)
|
||||
return FileType(t)
|
||||
}
|
||||
|
||||
return fileTypeUnknown
|
||||
return FileTypeUnknown
|
||||
}
|
||||
|
||||
func (kv KV) BlockCount() uint64 {
|
||||
@ -105,39 +106,44 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
r := keyValue(kv, key, &array{})
|
||||
s := make([]string, r.size)
|
||||
for i := range r.size {
|
||||
s[i] = r.values[i].(string)
|
||||
}
|
||||
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
|
||||
}
|
||||
|
||||
return s
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
r := keyValue(kv, key, &array{})
|
||||
s := make([]uint32, r.size)
|
||||
for i := range r.size {
|
||||
s[i] = uint32(r.values[i].(int32))
|
||||
}
|
||||
|
||||
return s
|
||||
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
r := keyValue(kv, key, &array{})
|
||||
s := make([]float32, r.size)
|
||||
for i := range r.size {
|
||||
s[i] = float32(r.values[i].(float32))
|
||||
}
|
||||
return s
|
||||
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return kv.Architecture() == "gemma3"
|
||||
return slices.Contains([]string{
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
"llama4",
|
||||
"mllama",
|
||||
"qwen25vl",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
|
||||
type valueTypes interface {
|
||||
uint8 | int8 | uint16 | int16 |
|
||||
uint32 | int32 | uint64 | int64 |
|
||||
string | float32 | float64 | bool
|
||||
}
|
||||
|
||||
type arrayValueTypes interface {
|
||||
*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
|
||||
*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
|
||||
*array[string] | *array[float32] | *array[float64] | *array[bool]
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
@ -146,7 +152,7 @@ func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key s
|
||||
return val.(T)
|
||||
}
|
||||
|
||||
slog.Warn("key not found", "key", key, "default", defaultValue[0])
|
||||
slog.Debug("key not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0]
|
||||
}
|
||||
|
||||
@ -223,7 +229,11 @@ func (t Tensor) block() (n int) {
|
||||
}
|
||||
|
||||
func (t Tensor) blockSize() uint64 {
|
||||
switch t.Kind {
|
||||
return (TensorType)(t.Kind).BlockSize()
|
||||
}
|
||||
|
||||
func (t TensorType) BlockSize() uint64 {
|
||||
switch t {
|
||||
case
|
||||
0, // F32
|
||||
1, // F16
|
||||
@ -249,73 +259,77 @@ func (t Tensor) blockSize() uint64 {
|
||||
}
|
||||
|
||||
func (t Tensor) typeSize() uint64 {
|
||||
blockSize := t.blockSize()
|
||||
return TensorType(t.Kind).TypeSize()
|
||||
}
|
||||
|
||||
switch t.Kind {
|
||||
case 0: // FP32
|
||||
func (t TensorType) TypeSize() uint64 {
|
||||
blockSize := t.BlockSize()
|
||||
|
||||
switch t {
|
||||
case TensorTypeF32:
|
||||
return 4
|
||||
case 1: // FP16
|
||||
case TensorTypeF16:
|
||||
return 2
|
||||
case 2: // Q4_0
|
||||
case TensorTypeQ4_0:
|
||||
return 2 + blockSize/2
|
||||
case 3: // Q4_1
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + blockSize/2
|
||||
case 6: // Q5_0
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + blockSize/2
|
||||
case 7: // Q5_1
|
||||
case TensorTypeQ5_1:
|
||||
return 2 + 2 + 4 + blockSize/2
|
||||
case 8: // Q8_0
|
||||
case TensorTypeQ8_0:
|
||||
return 2 + blockSize
|
||||
case 9: // Q8_1
|
||||
case TensorTypeQ8_1:
|
||||
return 2 + 2 + blockSize
|
||||
case 10: // Q2_K
|
||||
case TensorTypeQ2_K:
|
||||
return blockSize/16 + blockSize/4 + 2 + 2
|
||||
case 11: // Q3_K
|
||||
case TensorTypeQ3_K:
|
||||
return blockSize/8 + blockSize/4 + 12 + 2
|
||||
case 12: // Q4_K
|
||||
case TensorTypeQ4_K:
|
||||
return 2 + 2 + 12 + blockSize/2
|
||||
case 13: // Q5_K
|
||||
case TensorTypeQ5_K:
|
||||
return 2 + 2 + 12 + blockSize/8 + blockSize/2
|
||||
case 14: // Q6_K
|
||||
case TensorTypeQ6_K:
|
||||
return blockSize/2 + blockSize/4 + blockSize/16 + 2
|
||||
case 15: // Q8_K
|
||||
case TensorTypeQ8_K:
|
||||
return 4 + blockSize + 2*blockSize/16
|
||||
case 16: // IQ2_XXS
|
||||
case tensorTypeIQ2_XXS:
|
||||
return 2 + 2*blockSize/8
|
||||
case 17: // IQ2_XS
|
||||
case tensorTypeIQ2_XS:
|
||||
return 2 + 2*blockSize/8 + blockSize/32
|
||||
case 18: // IQ3_XXS
|
||||
case tensorTypeIQ3_XXS:
|
||||
return 2 + blockSize/4 + blockSize/8
|
||||
case 19: // IQ1_S
|
||||
case tensorTypeIQ1_S:
|
||||
return 2 + blockSize/8 + blockSize/16
|
||||
case 20: // IQ4_NL
|
||||
case tensorTypeIQ4_NL:
|
||||
return 2 + blockSize/2
|
||||
case 21: // IQ3_S
|
||||
case tensorTypeIQ3_S:
|
||||
return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
|
||||
case 22: // IQ2_S
|
||||
case tensorTypeIQ2_S:
|
||||
return 2 + blockSize/4 + blockSize/16
|
||||
case 23: // IQ4_XS
|
||||
case tensorTypeIQ4_XS:
|
||||
return 2 + 2 + blockSize/2 + blockSize/64
|
||||
case 24: // I8
|
||||
case TensorTypeI8:
|
||||
return 1
|
||||
case 25: // I16
|
||||
case TensorTypeI16:
|
||||
return 2
|
||||
case 26: // I32
|
||||
case TensorTypeI32:
|
||||
return 4
|
||||
case 27: // I64
|
||||
case TensorTypeI64:
|
||||
return 8
|
||||
case 28: // F64
|
||||
case TensorTypeF64:
|
||||
return 8
|
||||
case 29: // IQ1_M
|
||||
case tensorTypeIQ1_M:
|
||||
return blockSize/8 + blockSize/16 + blockSize/32
|
||||
case 30: // BF16
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (t Tensor) parameters() uint64 {
|
||||
func (t Tensor) Elements() uint64 {
|
||||
var count uint64 = 1
|
||||
for _, n := range t.Shape {
|
||||
count *= n
|
||||
@ -324,11 +338,11 @@ func (t Tensor) parameters() uint64 {
|
||||
}
|
||||
|
||||
func (t Tensor) Size() uint64 {
|
||||
return t.parameters() * t.typeSize() / t.blockSize()
|
||||
return t.Elements() * t.typeSize() / t.blockSize()
|
||||
}
|
||||
|
||||
func (t Tensor) Type() string {
|
||||
return fileType(t.Kind).String()
|
||||
return TensorType(t.Kind).String()
|
||||
}
|
||||
|
||||
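As a worked example of how these helpers combine, Size() is Elements() * TypeSize() / BlockSize(); the self-contained sketch below redoes that arithmetic for a hypothetical 4096x4096 Q4_0 tensor (block size 32, type size 2 + 32/2 = 18 bytes):

```go
package main

import "fmt"

func main() {
	// Hypothetical 4096x4096 tensor stored as Q4_0.
	shape := []uint64{4096, 4096}

	var elements uint64 = 1
	for _, n := range shape {
		elements *= n // mirrors Tensor.Elements()
	}

	const blockSize uint64 = 32             // Q4_0 packs 32 weights per block
	const typeSize uint64 = 2 + blockSize/2 // 2-byte scale + 16 bytes of 4-bit weights

	size := elements * typeSize / blockSize // mirrors Tensor.Size()
	fmt.Printf("%d elements -> %d bytes (~%.1f MiB)\n", elements, size, float64(size)/(1<<20))
}
```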
type container interface {
|
||||
@ -372,13 +386,8 @@ func DetectContentType(b []byte) string {
|
||||
// Decode decodes a GGML model from the given reader.
|
||||
//
|
||||
// It collects array values for arrays with a size less than or equal to
|
||||
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
|
||||
// the maxArraySize is negative, all arrays are collected.
|
||||
// maxArraySize. If the maxArraySize is negative, all arrays are collected.
|
||||
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
if maxArraySize == 0 {
|
||||
maxArraySize = 1024
|
||||
}
|
||||
|
||||
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
|
||||
|
||||
var magic uint32
|
||||
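For reference, calling Decode with a negative maxArraySize collects every array (for example the full tokenizer vocabulary), matching the updated tests above; the sketch below is illustrative and the file path is a placeholder:

```go
package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/fs/ggml"
)

func main() {
	f, err := os.Open("model.gguf") // placeholder path
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// A negative maxArraySize collects all array values; a positive limit
	// skips arrays larger than that size.
	m, _, err := ggml.Decode(f, -1)
	if err != nil {
		panic(err)
	}

	kv := m.KV()
	fmt.Println(kv.Architecture(), kv.ParameterCount())
}
```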
@ -413,11 +422,11 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
|
||||
|
||||
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||
@ -426,10 +435,13 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
case "llama":
|
||||
case "llama", "llama4":
|
||||
fullOffload = max(
|
||||
4*batch*(1+4*embedding+context*(1+heads)),
|
||||
4*batch*(embedding+vocab),
|
||||
@ -443,7 +455,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
|
||||
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
|
||||
// mixtral 8x22b
|
||||
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
|
||||
ff := uint64(f.KV().Uint("feed_forward_length"))
|
||||
partialOffload = max(
|
||||
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
|
||||
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
|
||||
@ -460,16 +472,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
case "mllama":
|
||||
var visionTokens, tiles uint64 = 1601, 4
|
||||
|
||||
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
||||
kv = headsKV *
|
||||
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
||||
(2* // sizeof(float16)
|
||||
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
||||
context +
|
||||
4* // sizeof(float32)
|
||||
uint64(crossAttentionLayers.size)* // num cross attention layers
|
||||
visionTokens*
|
||||
tiles)
|
||||
crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers")
|
||||
for i := range kv {
|
||||
if slices.Contains(crossAttentionLayers, int32(i)) {
|
||||
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
|
||||
4 * // sizeof(float32)
|
||||
visionTokens *
|
||||
tiles
|
||||
}
|
||||
}
|
||||
|
||||
fullOffload = max(
|
||||
@ -481,7 +491,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
var ropeFreqsCount uint64
|
||||
if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
|
||||
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
|
||||
ropeFreqsCount = ropeFreqsWeights.parameters()
|
||||
ropeFreqsCount = ropeFreqsWeights.Elements()
|
||||
}
|
||||
}
|
||||
|
||||
@ -505,6 +515,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
4*embeddingHeadsK*context*8+
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
)
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
// engine. Gemma3 always uses the Ollama engine.
|
||||
if f.KV().Architecture() == "gemma3" {
|
||||
const gemma3GlobalCacheCount = 6
|
||||
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
|
||||
for i := range kv {
|
||||
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
||||
// layers are the smaller local (sliding) layers.
|
||||
if (i+1)%gemma3GlobalCacheCount != 0 {
|
||||
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "command-r":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
@ -623,10 +647,36 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||
embeddingLength*numPatches*maxNumTiles +
|
||||
9*embeddingLength*numPaddedPatches*maxNumTiles +
|
||||
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
|
||||
case "gemma3":
|
||||
case "gemma3", "mistral3":
|
||||
graphSize = 4 * (imageSize*imageSize*numChannels +
|
||||
embeddingLength*patchSize +
|
||||
numPatches*numPatches*headCount)
|
||||
case "qwen25vl":
|
||||
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
|
||||
mergeSize := uint64(llm.KV().Uint("vision.spatial_merge_size", 2))
|
||||
temporalPatchSize := uint64(2)
|
||||
|
||||
// Calculate max possible patches based on max_pixels
|
||||
maxHeight := uint64(math.Sqrt(float64(maxPixels)))
|
||||
maxWidth := maxPixels / maxHeight
|
||||
maxGridHeight := maxHeight / patchSize
|
||||
maxGridWidth := maxWidth / patchSize
|
||||
// Account for merged patches (2x2 grid)
|
||||
numPatches := (maxGridHeight * maxGridWidth) / (mergeSize * mergeSize)
|
||||
|
||||
// Calculate graph size based on typical operations in ProcessImage and createPatches
|
||||
graphSize = 4 * (maxPixels*numChannels + // Original image storage
|
||||
// Normalized pixels
|
||||
maxPixels*numChannels +
|
||||
// Patches storage (numPatches * channels * temporalPatchSize * patchSize^2)
|
||||
numPatches*numChannels*temporalPatchSize*patchSize*patchSize +
|
||||
// Self-attention calculations (similar to other architectures)
|
||||
numPatches*numPatches*headCount +
|
||||
// Additional buffer for processing
|
||||
embeddingLength*numPatches)
|
||||
case "llama4":
|
||||
// vision graph is computed independently in the same schedule
|
||||
// and is negligible compared to the worst case text graph
|
||||
}
|
||||
|
||||
return weights, graphSize
|
||||
|
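To make the qwen25vl branch above concrete, the following self-contained sketch redoes the patch-count estimate with the defaults used there (max_pixels = 28*28*1280, patch size 14, spatial merge size 2); the numbers are illustrative only:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	const (
		maxPixels = 28 * 28 * 1280 // default vision.max_pixels
		patchSize = 14
		mergeSize = 2 // spatial merge: a 2x2 patch grid becomes one token
	)

	// Approximate the largest image as roughly square under the pixel budget.
	maxHeight := uint64(math.Sqrt(float64(maxPixels)))
	maxWidth := uint64(maxPixels) / maxHeight

	gridH := maxHeight / patchSize
	gridW := maxWidth / patchSize

	numPatches := (gridH * gridW) / (mergeSize * mergeSize)
	fmt.Printf("grid %dx%d -> %d merged patches\n", gridH, gridW, numPatches)
}
```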
@ -2,6 +2,7 @@ package ggml
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -210,3 +211,61 @@ func TestTensorTypes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyValue(t *testing.T) {
|
||||
kv := KV{
|
||||
"general.architecture": "test",
|
||||
"test.strings": &array[string]{size: 3, values: []string{"a", "b", "c"}},
|
||||
"test.float32s": &array[float32]{size: 3, values: []float32{1.0, 2.0, 3.0}},
|
||||
"test.int32s": &array[int32]{size: 3, values: []int32{1, 2, 3}},
|
||||
"test.uint32s": &array[uint32]{size: 3, values: []uint32{1, 2, 3}},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Strings("strings"), []string{"a", "b", "c"}); diff != "" {
|
||||
t.Errorf("unexpected strings (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Strings("nonexistent.strings"), []string(nil)); diff != "" {
|
||||
t.Errorf("unexpected strings (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Strings("default.strings", []string{"ollama"}), []string{"ollama"}); diff != "" {
|
||||
t.Errorf("unexpected strings (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Floats("float32s"), []float32{1.0, 2.0, 3.0}); diff != "" {
|
||||
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Floats("nonexistent.float32s"), []float32(nil)); diff != "" {
|
||||
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Floats("default.float32s", []float32{math.MaxFloat32}), []float32{math.MaxFloat32}); diff != "" {
|
||||
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Ints("int32s"), []int32{1, 2, 3}); diff != "" {
|
||||
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Ints("nonexistent.int32s"), []int32(nil)); diff != "" {
|
||||
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Ints("default.int32s", []int32{math.MaxInt32}), []int32{math.MaxInt32}); diff != "" {
|
||||
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Uints("uint32s"), []uint32{1, 2, 3}); diff != "" {
|
||||
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Uints("nonexistent.uint32s"), []uint32(nil)); diff != "" {
|
||||
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kv.Uints("default.uint32s", []uint32{math.MaxUint32}), []uint32{math.MaxUint32}); diff != "" {
|
||||
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
321 fs/ggml/gguf.go
@ -9,8 +9,12 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type containerGGUF struct {
|
||||
@ -36,10 +40,6 @@ type containerGGUF struct {
|
||||
maxArraySize int
|
||||
}
|
||||
|
||||
func (c *containerGGUF) canCollectArray(size int) bool {
|
||||
return c.maxArraySize < 0 || size <= c.maxArraySize
|
||||
}
|
||||
|
||||
func (c *containerGGUF) Name() string {
|
||||
return "gguf"
|
||||
}
|
||||
@ -229,16 +229,13 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||
}
|
||||
|
||||
llm.tensors = append(llm.tensors, &tensor)
|
||||
llm.parameters += tensor.parameters()
|
||||
llm.parameters += tensor.Elements()
|
||||
}
|
||||
|
||||
// patch KV with parameter count
|
||||
llm.kv["general.parameter_count"] = llm.parameters
|
||||
|
||||
alignment, ok := llm.kv["general.alignment"].(uint32)
|
||||
if !ok {
|
||||
alignment = 32
|
||||
}
|
||||
alignment := llm.kv.Uint("general.alignment", 32)
|
||||
|
||||
offset, err := rs.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
@ -298,6 +295,23 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
|
||||
for i := range a.size {
|
||||
if a.values != nil {
|
||||
e, err := readGGUFV1String(llm, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.values[i] = e
|
||||
} else {
|
||||
discardGGUFString(llm, r)
|
||||
}
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func discardGGUFString(llm *gguf, r io.Reader) error {
|
||||
buf := llm.scratch[:8]
|
||||
_, err := io.ReadFull(r, buf)
|
||||
@ -355,78 +369,44 @@ func writeGGUFString(w io.Writer, s string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
type array struct {
|
||||
size int
|
||||
values []any
|
||||
}
|
||||
|
||||
func (a *array) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(a.values)
|
||||
}
|
||||
|
||||
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
|
||||
t, err := readGGUF[uint32](llm, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n, err := readGGUF[uint32](llm, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a := &array{size: int(n)}
|
||||
if llm.canCollectArray(int(n)) {
|
||||
a.values = make([]any, 0, int(n))
|
||||
}
|
||||
|
||||
for i := range n {
|
||||
var e any
|
||||
switch t {
|
||||
case ggufTypeUint8:
|
||||
e, err = readGGUF[uint8](llm, r)
|
||||
case ggufTypeInt8:
|
||||
e, err = readGGUF[int8](llm, r)
|
||||
case ggufTypeUint16:
|
||||
e, err = readGGUF[uint16](llm, r)
|
||||
case ggufTypeInt16:
|
||||
e, err = readGGUF[int16](llm, r)
|
||||
case ggufTypeUint32:
|
||||
e, err = readGGUF[uint32](llm, r)
|
||||
case ggufTypeInt32:
|
||||
e, err = readGGUF[int32](llm, r)
|
||||
case ggufTypeUint64:
|
||||
e, err = readGGUF[uint64](llm, r)
|
||||
case ggufTypeInt64:
|
||||
e, err = readGGUF[int64](llm, r)
|
||||
case ggufTypeFloat32:
|
||||
e, err = readGGUF[float32](llm, r)
|
||||
case ggufTypeFloat64:
|
||||
e, err = readGGUF[float64](llm, r)
|
||||
case ggufTypeBool:
|
||||
e, err = readGGUF[bool](llm, r)
|
||||
case ggufTypeString:
|
||||
e, err = readGGUFV1String(llm, r)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid array type: %d", t)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
|
||||
for i := range a.size {
|
||||
if a.values != nil {
|
||||
e, err := readGGUFString(llm, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.values[i] = e
|
||||
} else {
|
||||
discardGGUFString(llm, r)
|
||||
}
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
|
||||
if llm.Version == 1 {
|
||||
return readGGUFV1Array(llm, r)
|
||||
}
|
||||
type array[T any] struct {
|
||||
// size is the actual size of the array
|
||||
size int
|
||||
|
||||
// values is the array of values. this is nil if the array is larger than configured maxSize
|
||||
values []T
|
||||
}
|
||||
|
||||
func (a *array[T]) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(a.values)
|
||||
}
|
||||
|
||||
func newArray[T any](size, maxSize int) *array[T] {
|
||||
a := array[T]{size: size}
|
||||
if maxSize < 0 || size <= maxSize {
|
||||
a.values = make([]T, size)
|
||||
}
|
||||
return &a
|
||||
}
|
||||
|
||||
func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
|
||||
t, err := readGGUF[uint32](llm, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -437,45 +417,55 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a := &array{size: int(n)}
|
||||
if llm.canCollectArray(int(n)) {
|
||||
a.values = make([]any, int(n))
|
||||
}
|
||||
|
||||
for i := range n {
|
||||
var e any
|
||||
switch t {
|
||||
case ggufTypeUint8:
|
||||
e, err = readGGUF[uint8](llm, r)
|
||||
case ggufTypeInt8:
|
||||
e, err = readGGUF[int8](llm, r)
|
||||
case ggufTypeUint16:
|
||||
e, err = readGGUF[uint16](llm, r)
|
||||
case ggufTypeInt16:
|
||||
e, err = readGGUF[int16](llm, r)
|
||||
case ggufTypeUint32:
|
||||
e, err = readGGUF[uint32](llm, r)
|
||||
case ggufTypeInt32:
|
||||
e, err = readGGUF[int32](llm, r)
|
||||
case ggufTypeUint64:
|
||||
e, err = readGGUF[uint64](llm, r)
|
||||
case ggufTypeInt64:
|
||||
e, err = readGGUF[int64](llm, r)
|
||||
case ggufTypeFloat32:
|
||||
e, err = readGGUF[float32](llm, r)
|
||||
case ggufTypeFloat64:
|
||||
e, err = readGGUF[float64](llm, r)
|
||||
case ggufTypeBool:
|
||||
e, err = readGGUF[bool](llm, r)
|
||||
case ggufTypeString:
|
||||
if a.values != nil {
|
||||
e, err = readGGUFString(llm, r)
|
||||
} else {
|
||||
err = discardGGUFString(llm, r)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid array type: %d", t)
|
||||
switch t {
|
||||
case ggufTypeUint8:
|
||||
a := newArray[uint8](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeInt8:
|
||||
a := newArray[int8](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeUint16:
|
||||
a := newArray[uint16](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeInt16:
|
||||
a := newArray[int16](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeUint32:
|
||||
a := newArray[uint32](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeInt32:
|
||||
a := newArray[int32](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeUint64:
|
||||
a := newArray[uint64](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeInt64:
|
||||
a := newArray[int64](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeFloat32:
|
||||
a := newArray[float32](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeFloat64:
|
||||
a := newArray[float64](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeBool:
|
||||
a := newArray[bool](int(n), llm.maxArraySize)
|
||||
return readGGUFArrayData(llm, r, a)
|
||||
case ggufTypeString:
|
||||
a := newArray[string](int(n), llm.maxArraySize)
|
||||
if llm.Version == 1 {
|
||||
return readGGUFV1StringsData(llm, r, a)
|
||||
}
|
||||
|
||||
return readGGUFStringsData(llm, r, a)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid array type: %d", t)
|
||||
}
|
||||
}
|
||||
|
||||
func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) {
|
||||
for i := range a.size {
|
||||
e, err := readGGUF[T](llm, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -502,23 +492,38 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if t == ggufTypeString {
|
||||
for _, e := range any(s).([]string) {
|
||||
if err := binary.Write(w, binary.LittleEndian, uint64(len(e))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, binary.LittleEndian, []byte(e)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return binary.Write(w, binary.LittleEndian, s)
|
||||
}
|
||||
|
||||
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
|
||||
if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
|
||||
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
alignment := kv.Uint("general.alignment", 32)
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(ts))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -526,12 +531,12 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
|
||||
slices.Sort(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
if err := ggufWriteKV(ws, key, kv[key]); err != nil {
|
||||
if err := ggufWriteKV(f, key, kv[key]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(ts, func(a, b Tensor) int {
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
if i, j := a.block(), b.block(); i < 0 && j > 0 {
|
||||
return 1
|
||||
} else if i > 0 && j < 0 {
|
||||
@ -542,22 +547,34 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
|
||||
})
|
||||
|
||||
var s uint64
|
||||
for _, t := range ts {
|
||||
t.Offset = s
|
||||
if err := ggufWriteTensorInfo(ws, t); err != nil {
|
||||
for i := range ts {
|
||||
ts[i].Offset = s
|
||||
if err := ggufWriteTensorInfo(f, ts[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
s += t.Size()
|
||||
s += ts[i].Size()
|
||||
s += uint64(ggufPadding(int64(s), int64(alignment)))
|
||||
}
|
||||
|
||||
var alignment int64 = 32
|
||||
offset, err := f.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offset += ggufPadding(offset, int64(alignment))
|
||||
|
||||
var g errgroup.Group
|
||||
g.SetLimit(runtime.GOMAXPROCS(0))
|
||||
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
|
||||
for _, t := range ts {
|
||||
if err := ggufWriteTensor(ws, t, alignment); err != nil {
|
||||
t := t
|
||||
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
|
||||
g.Go(func() error {
|
||||
_, err := t.WriteTo(w)
|
||||
return err
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
@ -572,8 +589,10 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
|
||||
var err error
|
||||
switch v := v.(type) {
|
||||
case uint32:
|
||||
case uint32, FileType:
|
||||
err = writeGGUF(ws, ggufTypeUint32, v)
|
||||
case uint64:
|
||||
err = writeGGUF(ws, ggufTypeUint64, v)
|
||||
case float32:
|
||||
err = writeGGUF(ws, ggufTypeFloat32, v)
|
||||
case bool:
|
||||
@ -582,32 +601,20 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
err = writeGGUFString(ws, v)
|
||||
case []int32:
|
||||
err = writeGGUFArray(ws, ggufTypeInt32, v)
|
||||
case *array[int32]:
|
||||
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
|
||||
case []uint32:
|
||||
err = writeGGUFArray(ws, ggufTypeUint32, v)
|
||||
case *array[uint32]:
|
||||
err = writeGGUFArray(ws, ggufTypeUint32, v.values)
|
||||
case []float32:
|
||||
err = writeGGUFArray(ws, ggufTypeFloat32, v)
|
||||
case *array[float32]:
|
||||
err = writeGGUFArray(ws, ggufTypeFloat32, v.values)
|
||||
case []string:
|
||||
if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, e := range v {
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = writeGGUFArray(ws, ggufTypeString, v)
|
||||
case *array[string]:
|
||||
err = writeGGUFArray(ws, ggufTypeString, v.values)
|
||||
default:
|
||||
return fmt.Errorf("improper type for '%s'", k)
|
||||
}
|
||||
@ -615,7 +622,7 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
|
||||
func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
|
||||
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
|
||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
|
||||
return err
|
||||
@ -629,8 +636,8 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range len(t.Shape) {
|
||||
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
|
||||
for _, n := range t.Shape {
|
||||
if err := binary.Write(ws, binary.LittleEndian, n); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -642,20 +649,6 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
|
||||
return binary.Write(ws, binary.LittleEndian, t.Offset)
|
||||
}
|
||||
|
||||
func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
|
||||
offset, err := ws.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = t.WriteTo(ws)
|
||||
return err
|
||||
}
|
||||
|
||||
func ggufPadding(offset, align int64) int64 {
|
||||
return (align - offset%align) % align
|
||||
}
|
||||
|
63
fs/ggml/gguf_test.go
Normal file
63
fs/ggml/gguf_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package ggml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
w, err := os.CreateTemp(t.TempDir(), "*.bin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
if err := WriteGGUF(w, KV{
|
||||
"general.alignment": uint32(16),
|
||||
}, []*Tensor{
|
||||
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := os.Open(w.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
ff, _, err := Decode(r, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ff.KV(), KV{
|
||||
"general.alignment": uint32(16),
|
||||
"general.parameter_count": uint64(36),
|
||||
}); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(ff.Tensors(), Tensors{
|
||||
Offset: 336,
|
||||
items: []*Tensor{
|
||||
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
|
||||
},
|
||||
}, cmp.AllowUnexported(Tensors{})); diff != "" {
|
||||
t.Errorf("Mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
343
fs/ggml/type.go
343
fs/ggml/type.go
@ -1,26 +1,31 @@
|
||||
package ggml
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type fileType uint32
|
||||
// FileType is the Go equivalent to llama_ftype used for gguf file typing
|
||||
type FileType uint32
|
||||
|
||||
const (
|
||||
fileTypeF32 fileType = iota
|
||||
fileTypeF16
|
||||
FileTypeF32 FileType = iota
|
||||
FileTypeF16
|
||||
fileTypeQ4_0
|
||||
fileTypeQ4_1
|
||||
fileTypeQ4_1_F16
|
||||
fileTypeQ4_2 // unused
|
||||
fileTypeQ4_3 // unused
|
||||
fileTypeQ8_0
|
||||
fileTypeQ4_1_F16 // unused by GGML
|
||||
fileTypeQ4_2 // unused by GGML
|
||||
fileTypeQ4_3 // unused by GGML
|
||||
FileTypeQ8_0
|
||||
fileTypeQ5_0
|
||||
fileTypeQ5_1
|
||||
fileTypeQ2_K
|
||||
fileTypeQ3_K_S
|
||||
fileTypeQ3_K_M
|
||||
fileTypeQ3_K_L
|
||||
fileTypeQ4_K_S
|
||||
fileTypeQ4_K_M
|
||||
FileTypeQ4_K_S
|
||||
FileTypeQ4_K_M
|
||||
fileTypeQ5_K_S
|
||||
fileTypeQ5_K_M
|
||||
fileTypeQ6_K
|
||||
@ -37,93 +42,62 @@ const (
|
||||
fileTypeIQ2_M
|
||||
fileTypeIQ4_XS
|
||||
fileTypeIQ1_M
|
||||
fileTypeBF16
|
||||
FileTypeBF16
|
||||
fileTypeQ4_0_4_4 // unused by GGML
|
||||
fileTypeQ4_0_4_8 // unused by GGML
|
||||
fileTypeQ4_0_8_8 // unused by GGML
|
||||
fileTypeTQ1_0
|
||||
fileTypeTQ2_0
|
||||
|
||||
fileTypeUnknown
|
||||
FileTypeUnknown = 1024
|
||||
)
|
||||
|
||||
func ParseFileType(s string) (fileType, error) {
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
// Only Ollama supported types are considered valid
|
||||
func ParseFileType(s string) (FileType, error) {
|
||||
switch s {
|
||||
case "F32":
|
||||
return fileTypeF32, nil
|
||||
return FileTypeF32, nil
|
||||
case "F16":
|
||||
return fileTypeF16, nil
|
||||
case "Q4_0":
|
||||
return fileTypeQ4_0, nil
|
||||
case "Q4_1":
|
||||
return fileTypeQ4_1, nil
|
||||
case "Q4_1_F16":
|
||||
return fileTypeQ4_1_F16, nil
|
||||
return FileTypeF16, nil
|
||||
case "Q8_0":
|
||||
return fileTypeQ8_0, nil
|
||||
case "Q5_0":
|
||||
return fileTypeQ5_0, nil
|
||||
case "Q5_1":
|
||||
return fileTypeQ5_1, nil
|
||||
case "Q2_K":
|
||||
return fileTypeQ2_K, nil
|
||||
case "Q3_K_S":
|
||||
return fileTypeQ3_K_S, nil
|
||||
case "Q3_K_M":
|
||||
return fileTypeQ3_K_M, nil
|
||||
case "Q3_K_L":
|
||||
return fileTypeQ3_K_L, nil
|
||||
return FileTypeQ8_0, nil
|
||||
case "Q4_K_S":
|
||||
return fileTypeQ4_K_S, nil
|
||||
case "Q4_K_M":
|
||||
return fileTypeQ4_K_M, nil
|
||||
case "Q5_K_S":
|
||||
return fileTypeQ5_K_S, nil
|
||||
case "Q5_K_M":
|
||||
return fileTypeQ5_K_M, nil
|
||||
case "Q6_K":
|
||||
return fileTypeQ6_K, nil
|
||||
case "IQ2_XXS":
|
||||
return fileTypeIQ2_XXS, nil
|
||||
case "IQ2_XS":
|
||||
return fileTypeIQ2_XS, nil
|
||||
case "Q2_K_S":
|
||||
return fileTypeQ2_K_S, nil
|
||||
case "IQ3_XS":
|
||||
return fileTypeIQ3_XS, nil
|
||||
case "IQ3_XXS":
|
||||
return fileTypeIQ3_XXS, nil
|
||||
case "IQ1_S":
|
||||
return fileTypeIQ1_S, nil
|
||||
case "IQ4_NL":
|
||||
return fileTypeIQ4_NL, nil
|
||||
case "IQ3_S":
|
||||
return fileTypeIQ3_S, nil
|
||||
case "IQ3_M":
|
||||
return fileTypeIQ3_M, nil
|
||||
case "IQ2_S":
|
||||
return fileTypeIQ2_S, nil
|
||||
case "IQ2_M":
|
||||
return fileTypeIQ2_M, nil
|
||||
case "IQ4_XS":
|
||||
return fileTypeIQ4_XS, nil
|
||||
case "IQ1_M":
|
||||
return fileTypeIQ1_M, nil
|
||||
return FileTypeQ4_K_S, nil
|
||||
case "Q4_K_M", "Q4_K":
|
||||
return FileTypeQ4_K_M, nil
|
||||
case "BF16":
|
||||
return fileTypeBF16, nil
|
||||
return FileTypeBF16, nil
|
||||
default:
|
||||
return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
|
||||
supportedFileTypes := []FileType{
|
||||
FileTypeF32,
|
||||
FileTypeF16,
|
||||
FileTypeQ4_K_S,
|
||||
FileTypeQ4_K_M,
|
||||
FileTypeQ8_0,
|
||||
// fsggml.FileTypeBF16, // TODO
|
||||
}
|
||||
strs := make([]string, len(supportedFileTypes))
|
||||
for i := range supportedFileTypes {
|
||||
strs[i] = supportedFileTypes[i].String()
|
||||
}
|
||||
|
||||
return FileTypeUnknown, fmt.Errorf("unsupported quantization type %s - supported types are %s", s, strings.Join(strs, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
func (t fileType) String() string {
|
||||
func (t FileType) String() string {
|
||||
// Note: this routine will return a broader set of file types for existing models
|
||||
switch t {
|
||||
case fileTypeF32:
|
||||
case FileTypeF32:
|
||||
return "F32"
|
||||
case fileTypeF16:
|
||||
case FileTypeF16:
|
||||
return "F16"
|
||||
case fileTypeQ4_0:
|
||||
return "Q4_0"
|
||||
case fileTypeQ4_1:
|
||||
return "Q4_1"
|
||||
case fileTypeQ4_1_F16:
|
||||
return "Q4_1_F16"
|
||||
case fileTypeQ8_0:
|
||||
case FileTypeQ8_0:
|
||||
return "Q8_0"
|
||||
case fileTypeQ5_0:
|
||||
return "Q5_0"
|
||||
@ -137,9 +111,9 @@ func (t fileType) String() string {
|
||||
return "Q3_K_M"
|
||||
case fileTypeQ3_K_L:
|
||||
return "Q3_K_L"
|
||||
case fileTypeQ4_K_S:
|
||||
case FileTypeQ4_K_S:
|
||||
return "Q4_K_S"
|
||||
case fileTypeQ4_K_M:
|
||||
case FileTypeQ4_K_M:
|
||||
return "Q4_K_M"
|
||||
case fileTypeQ5_K_S:
|
||||
return "Q5_K_S"
|
||||
@ -147,39 +121,198 @@ func (t fileType) String() string {
|
||||
return "Q5_K_M"
|
||||
case fileTypeQ6_K:
|
||||
return "Q6_K"
|
||||
case fileTypeIQ2_XXS:
|
||||
return "IQ2_XXS"
|
||||
case fileTypeIQ2_XS:
|
||||
return "IQ2_XS"
|
||||
case fileTypeQ2_K_S:
|
||||
return "Q2_K_S"
|
||||
case fileTypeIQ3_XS:
|
||||
return "IQ3_XS"
|
||||
case fileTypeIQ3_XXS:
|
||||
return "IQ3_XXS"
|
||||
case fileTypeIQ1_S:
|
||||
return "IQ1_S"
|
||||
case fileTypeIQ4_NL:
|
||||
return "IQ4_NL"
|
||||
case fileTypeIQ3_S:
|
||||
return "IQ3_S"
|
||||
case fileTypeIQ3_M:
|
||||
return "IQ3_M"
|
||||
case fileTypeIQ2_S:
|
||||
return "IQ2_S"
|
||||
case fileTypeIQ4_XS:
|
||||
return "IQ4_XS"
|
||||
case fileTypeIQ2_M:
|
||||
return "IQ2_M"
|
||||
case fileTypeIQ1_M:
|
||||
return "IQ1_M"
|
||||
case fileTypeBF16:
|
||||
case FileTypeBF16:
|
||||
return "BF16"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t fileType) Value() uint32 {
|
||||
func (t FileType) Value() uint32 {
|
||||
return uint32(t)
|
||||
}
|
||||
|
||||
func (ftype FileType) ToTensorType() TensorType {
|
||||
switch ftype {
|
||||
case FileTypeF32:
|
||||
return TensorTypeF32
|
||||
case FileTypeF16:
|
||||
return TensorTypeF16
|
||||
case fileTypeQ4_0:
|
||||
return TensorTypeQ4_0
|
||||
case fileTypeQ4_1:
|
||||
return TensorTypeQ4_1
|
||||
case FileTypeQ8_0:
|
||||
return TensorTypeQ8_0
|
||||
case fileTypeQ5_0:
|
||||
return TensorTypeQ5_0
|
||||
case fileTypeQ5_1:
|
||||
return TensorTypeQ5_1
|
||||
case fileTypeQ2_K:
|
||||
return TensorTypeQ2_K
|
||||
case fileTypeQ3_K_S:
|
||||
return TensorTypeQ3_K
|
||||
case fileTypeQ3_K_M:
|
||||
return TensorTypeQ3_K
|
||||
case fileTypeQ3_K_L:
|
||||
return TensorTypeQ3_K
|
||||
case FileTypeQ4_K_S:
|
||||
return TensorTypeQ4_K
|
||||
case FileTypeQ4_K_M:
|
||||
return TensorTypeQ4_K
|
||||
case fileTypeQ5_K_S:
|
||||
return TensorTypeQ5_K
|
||||
case fileTypeQ5_K_M:
|
||||
return TensorTypeQ5_K
|
||||
case fileTypeQ6_K:
|
||||
return TensorTypeQ6_K
|
||||
case fileTypeQ2_K_S:
|
||||
return TensorTypeQ2_K
|
||||
case FileTypeBF16:
|
||||
return TensorTypeBF16
|
||||
default:
|
||||
slog.Warn("unsupported file type", "type", ftype)
|
||||
return 0 // F32
|
||||
}
|
||||
}
|
||||
|
||||
// TensorType is equivalent to ggml_type for individual tensor types
|
||||
// Note: these are not the same as FileType
|
||||
type TensorType uint32
|
||||
|
||||
const (
|
||||
TensorTypeF32 TensorType = iota
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
tensorTypeQ4_2 // unused by GGML
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
TensorTypeQ8_1
|
||||
TensorTypeQ2_K
|
||||
TensorTypeQ3_K
|
||||
TensorTypeQ4_K
|
||||
TensorTypeQ5_K
|
||||
TensorTypeQ6_K
|
||||
TensorTypeQ8_K
|
||||
tensorTypeIQ2_XXS // not supported by ollama
|
||||
tensorTypeIQ2_XS // not supported by ollama
|
||||
tensorTypeIQ3_XXS // not supported by ollama
|
||||
tensorTypeIQ1_S // not supported by ollama
|
||||
tensorTypeIQ4_NL // not supported by ollama
|
||||
tensorTypeIQ3_S // not supported by ollama
|
||||
tensorTypeIQ2_S // not supported by ollama
|
||||
tensorTypeIQ4_XS // not supported by ollama
|
||||
TensorTypeI8
|
||||
TensorTypeI16
|
||||
TensorTypeI32
|
||||
TensorTypeI64
|
||||
TensorTypeF64
|
||||
tensorTypeIQ1_M // not supported by ollama
|
||||
TensorTypeBF16
|
||||
tensorTypeQ4_0_4_4 // unused by GGML
|
||||
tensorTypeQ4_0_4_8 // unused by GGML
|
||||
tensorTypeQ4_0_8_8 // unused by GGML
|
||||
tensorTypeTQ1_0 // not supported by ollama
|
||||
tensorTypeTQ2_0 // not supported by ollama
|
||||
tensorTypeIQ4_NL_4_4 // unused by GGML
|
||||
tensorTypeIQ4_NL_4_8 // unused by GGML
|
||||
tensorTypeIQ4_NL_8_8 // unused by GGML
|
||||
)
|
||||
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
// Only Ollama supported types are considered valid
|
||||
func ParseTensorType(s string) (TensorType, error) {
|
||||
switch s {
|
||||
case "F32":
|
||||
return TensorTypeF32, nil
|
||||
case "F16":
|
||||
return TensorTypeF16, nil
|
||||
case "Q4_0":
|
||||
return TensorTypeQ4_0, nil
|
||||
case "Q4_1":
|
||||
return TensorTypeQ4_1, nil
|
||||
case "Q5_0":
|
||||
return TensorTypeQ5_0, nil
|
||||
case "Q5_1":
|
||||
return TensorTypeQ5_1, nil
|
||||
case "Q8_0":
|
||||
return TensorTypeQ8_0, nil
|
||||
case "Q8_1":
|
||||
return TensorTypeQ8_1, nil
|
||||
case "Q2_K":
|
||||
return TensorTypeQ2_K, nil
|
||||
case "Q3_K":
|
||||
return TensorTypeQ3_K, nil
|
||||
case "Q4_K":
|
||||
return TensorTypeQ4_K, nil
|
||||
case "Q5_K":
|
||||
return TensorTypeQ5_K, nil
|
||||
case "Q6_K":
|
||||
return TensorTypeQ6_K, nil
|
||||
case "Q8_K":
|
||||
return TensorTypeQ8_K, nil
|
||||
case "F64":
|
||||
return TensorTypeF64, nil
|
||||
case "BF16":
|
||||
return TensorTypeBF16, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported quantization type %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) IsQuantized() bool {
|
||||
switch t {
|
||||
case TensorTypeF32, TensorTypeF16, TensorTypeBF16:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (t TensorType) RowSize(ne uint64) uint64 {
|
||||
return t.TypeSize() * ne / t.BlockSize()
|
||||
}
|
||||
|
||||
func (t TensorType) String() string {
|
||||
switch t {
|
||||
case TensorTypeF32:
|
||||
return "F32"
|
||||
case TensorTypeF16:
|
||||
return "F16"
|
||||
case TensorTypeQ4_0:
|
||||
return "Q4_0"
|
||||
case TensorTypeQ4_1:
|
||||
return "Q4_1"
|
||||
case TensorTypeQ5_0:
|
||||
return "Q5_0"
|
||||
case TensorTypeQ5_1:
|
||||
return "Q5_1"
|
||||
case TensorTypeQ8_0:
|
||||
return "Q8_0"
|
||||
case TensorTypeQ8_1:
|
||||
return "Q8_1"
|
||||
case TensorTypeQ2_K:
|
||||
return "Q2_K"
|
||||
case TensorTypeQ3_K:
|
||||
return "Q3_K"
|
||||
case TensorTypeQ4_K:
|
||||
return "Q4_K"
|
||||
case TensorTypeQ5_K:
|
||||
return "Q5_K"
|
||||
case TensorTypeQ6_K:
|
||||
return "Q6_K"
|
||||
case TensorTypeQ8_K:
|
||||
return "Q8_K"
|
||||
case TensorTypeF64:
|
||||
return "F64"
|
||||
case TensorTypeBF16:
|
||||
return "BF16"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
12
go.mod
12
go.mod
@ -11,7 +11,7 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.11.0
|
||||
golang.org/x/sync v0.12.0
|
||||
)
|
||||
|
||||
require (
|
||||
@ -70,12 +70,12 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.33.0
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
|
||||
golang.org/x/net v0.35.0 // indirect
|
||||
golang.org/x/sys v0.30.0
|
||||
golang.org/x/term v0.29.0
|
||||
golang.org/x/text v0.22.0
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0
|
||||
golang.org/x/term v0.30.0
|
||||
golang.org/x/text v0.23.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
24
go.sum
24
go.sum
@ -214,8 +214,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
||||
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
|
||||
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
|
||||
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
412
integration/api_test.go
Normal file
412
integration/api_test.go
Normal file
@ -0,0 +1,412 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestAPIGenerate(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue? be brief",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering"}
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
stream bool
|
||||
}{
|
||||
{
|
||||
name: "stream",
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
name: "no_stream",
|
||||
stream: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
// Fields that must always be present
|
||||
if response.Model == "" {
|
||||
t.Errorf("response missing model: %#v", response)
|
||||
}
|
||||
if response.Done {
|
||||
// Required fields for final updates:
|
||||
if response.DoneReason == "" && *req.Stream {
|
||||
// TODO - is the lack of done reason on non-stream a bug?
|
||||
t.Errorf("final response missing done_reason: %#v", response)
|
||||
}
|
||||
if response.Metrics.TotalDuration == 0 {
|
||||
t.Errorf("final response missing total_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.LoadDuration == 0 {
|
||||
t.Errorf("final response missing load_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.PromptEvalDuration == 0 {
|
||||
t.Errorf("final response missing prompt_eval_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.EvalCount == 0 {
|
||||
t.Errorf("final response missing eval_count: %#v", response)
|
||||
}
|
||||
if response.Metrics.EvalDuration == 0 {
|
||||
t.Errorf("final response missing eval_duration: %#v", response)
|
||||
}
|
||||
if len(response.Context) == 0 {
|
||||
t.Errorf("final response missing context: %#v", response)
|
||||
}
|
||||
|
||||
// Note: caching can result in no prompt eval count, so this can't be verified reliably
|
||||
// if response.Metrics.PromptEvalCount == 0 {
|
||||
// t.Errorf("final response missing prompt_eval_count: %#v", response)
|
||||
// }
|
||||
|
||||
} // else incremental response, nothing to check right now...
|
||||
buf.Write([]byte(response.Response))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
req.Stream = &test.stream
|
||||
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
|
||||
genErr = client.Generate(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("failed with %s request prompt %s ", req.Model, req.Prompt)
|
||||
}
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Errorf("none of %v found in %s", anyResp, response)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Validate PS while we're at it...
|
||||
resp, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list models API error: %s", err)
|
||||
}
|
||||
if resp == nil || len(resp.Models) == 0 {
|
||||
t.Fatalf("list models API returned empty list while model should still be loaded")
|
||||
}
|
||||
// Find the model we just loaded and verify some attributes
|
||||
found := false
|
||||
for _, model := range resp.Models {
|
||||
if strings.Contains(model.Name, req.Model) {
|
||||
found = true
|
||||
if model.Model == "" {
|
||||
t.Errorf("model field omitted: %#v", model)
|
||||
}
|
||||
if model.Size == 0 {
|
||||
t.Errorf("size omitted: %#v", model)
|
||||
}
|
||||
if model.Digest == "" {
|
||||
t.Errorf("digest omitted: %#v", model)
|
||||
}
|
||||
verifyModelDetails(t, model.Details)
|
||||
var nilTime time.Time
|
||||
if model.ExpiresAt == nilTime {
|
||||
t.Errorf("expires_at omitted: %#v", model)
|
||||
}
|
||||
// SizeVRAM could be zero.
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("unable to locate running model: %#v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIChat(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "why is the sky blue? be brief",
|
||||
},
|
||||
},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering"}
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
stream bool
|
||||
}{
|
||||
{
|
||||
name: "stream",
|
||||
stream: true,
|
||||
},
|
||||
{
|
||||
name: "no_stream",
|
||||
stream: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
fn := func(response api.ChatResponse) error {
|
||||
// Fields that must always be present
|
||||
if response.Model == "" {
|
||||
t.Errorf("response missing model: %#v", response)
|
||||
}
|
||||
if response.Done {
|
||||
// Required fields for final updates:
|
||||
var nilTime time.Time
|
||||
if response.CreatedAt == nilTime {
|
||||
t.Errorf("final response missing total_duration: %#v", response)
|
||||
}
|
||||
if response.DoneReason == "" {
|
||||
t.Errorf("final response missing done_reason: %#v", response)
|
||||
}
|
||||
if response.Metrics.TotalDuration == 0 {
|
||||
t.Errorf("final response missing total_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.LoadDuration == 0 {
|
||||
t.Errorf("final response missing load_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.PromptEvalDuration == 0 {
|
||||
t.Errorf("final response missing prompt_eval_duration: %#v", response)
|
||||
}
|
||||
if response.Metrics.EvalCount == 0 {
|
||||
t.Errorf("final response missing eval_count: %#v", response)
|
||||
}
|
||||
if response.Metrics.EvalDuration == 0 {
|
||||
t.Errorf("final response missing eval_duration: %#v", response)
|
||||
}
|
||||
|
||||
if response.Metrics.PromptEvalCount == 0 {
|
||||
t.Errorf("final response missing prompt_eval_count: %#v", response)
|
||||
}
|
||||
} // else incremental response, nothing to check right now...
|
||||
buf.Write([]byte(response.Message.Content))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
req.Stream = &test.stream
|
||||
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("chat never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("chat stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("failed with %s request prompt %v", req.Model, req.Messages)
|
||||
}
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Errorf("none of %v found in %s", anyResp, response)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for chat")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIListModels(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Make sure we have at least one model so an empty list can be considered a failure
|
||||
if err := PullIfMissing(ctx, client, smol); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
resp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to list models: %s", err)
|
||||
}
|
||||
if len(resp.Models) == 0 {
|
||||
t.Fatalf("list should not be empty")
|
||||
}
|
||||
model := resp.Models[0]
|
||||
if model.Name == "" {
|
||||
t.Errorf("first model name empty: %#v", model)
|
||||
}
|
||||
var nilTime time.Time
|
||||
if model.ModifiedAt == nilTime {
|
||||
t.Errorf("first model modified_at empty: %#v", model)
|
||||
}
|
||||
if model.Size == 0 {
|
||||
t.Errorf("first model size empty: %#v", model)
|
||||
}
|
||||
if model.Digest == "" {
|
||||
t.Errorf("first model digest empty: %#v", model)
|
||||
}
|
||||
verifyModelDetails(t, model.Details)
|
||||
}
|
||||
|
||||
func verifyModelDetails(t *testing.T, details api.ModelDetails) {
|
||||
if details.Format == "" {
|
||||
t.Errorf("first model details.format empty: %#v", details)
|
||||
}
|
||||
if details.Family == "" {
|
||||
t.Errorf("first model details.family empty: %#v", details)
|
||||
}
|
||||
if details.ParameterSize == "" {
|
||||
t.Errorf("first model details.parameter_size empty: %#v", details)
|
||||
}
|
||||
if details.QuantizationLevel == "" {
|
||||
t.Errorf("first model details.quantization_level empty: %#v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIShowModel(t *testing.T) {
|
||||
modelName := "llama3.2"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
if err := PullIfMissing(ctx, client, modelName); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: modelName})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to show model: %s", err)
|
||||
}
|
||||
if resp.License == "" {
|
||||
t.Errorf("%s missing license: %#v", modelName, resp)
|
||||
}
|
||||
if resp.Modelfile == "" {
|
||||
t.Errorf("%s missing modelfile: %#v", modelName, resp)
|
||||
}
|
||||
if resp.Parameters == "" {
|
||||
t.Errorf("%s missing parameters: %#v", modelName, resp)
|
||||
}
|
||||
if resp.Template == "" {
|
||||
t.Errorf("%s missing template: %#v", modelName, resp)
|
||||
}
|
||||
// llama3 omits system
|
||||
verifyModelDetails(t, resp.Details)
|
||||
// llama3 ommits messages
|
||||
if len(resp.ModelInfo) == 0 {
|
||||
t.Errorf("%s missing model_info: %#v", modelName, resp)
|
||||
}
|
||||
// llama3 omits projectors
|
||||
var nilTime time.Time
|
||||
if resp.ModifiedAt == nilTime {
|
||||
t.Errorf("%s missing modified_at: %#v", modelName, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIEmbeddings(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
resp, err := client.Embeddings(ctx, &req)
|
||||
if err != nil {
|
||||
t.Fatalf("embeddings call failed %s", err)
|
||||
}
|
||||
if len(resp.Embedding) == 0 {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
}
|
@ -14,15 +14,15 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOrcaMiniBlueSky(t *testing.T) {
|
||||
func TestBlueSky(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
@ -31,6 +31,7 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnicode(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
@ -39,7 +40,7 @@ func TestUnicode(t *testing.T) {
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
|
||||
Prompt: "天空为什么是蓝色的?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
// Workaround deepseek context shifting bug
|
||||
@ -61,7 +62,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
Model: "gemma2:2b",
|
||||
Prompt: "Output some smily face emoji",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
@ -93,10 +94,10 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
|
@ -21,11 +21,11 @@ func TestMultiModelConcurrency(t *testing.T) {
|
||||
var (
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Model: "llama3.2:1b",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) {
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@ -67,7 +67,7 @@ func TestMultiModelConcurrency(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||
func TestIntegrationConcurrentPredict(t *testing.T) {
|
||||
req, resp := GenerateRequests()
|
||||
reqLimit := len(req)
|
||||
iterLimit := 5
|
||||
@ -117,6 +117,9 @@ func TestMultiModelStress(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if maxVram < 2*format.GibiByte {
|
||||
t.Skip("VRAM less than 2G, skipping model stress tests")
|
||||
}
|
||||
|
||||
type model struct {
|
||||
name string
|
||||
@ -125,8 +128,8 @@ func TestMultiModelStress(t *testing.T) {
|
||||
|
||||
smallModels := []model{
|
||||
{
|
||||
name: "orca-mini",
|
||||
size: 2992 * format.MebiByte,
|
||||
name: "llama3.2:1b",
|
||||
size: 2876 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "phi",
|
||||
|
@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
Model: "llama2",
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_ctx": 128,
|
||||
@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
Model: "llama2",
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_ctx": 128,
|
||||
|
@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
|
||||
func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := embeddingTestHelper(ctx, t, req)
|
||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
func TestAllMiniLMEmbed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
truncTrue, truncFalse := true, false
|
||||
|
||||
@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
res := make(map[string]*api.EmbedResponse)
|
||||
|
||||
for _, req := range reqs {
|
||||
response, err := embedTestHelper(ctx, t, req.Request)
|
||||
response, err := embedTestHelper(ctx, client, t, req.Request)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
}
|
||||
|
||||
// check that truncate set to false returns an error if context length is exceeded
|
||||
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
|
||||
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
|
@ -12,58 +12,51 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntegrationLlava(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: "llava:7b",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
func TestVisionModels(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
type testCase struct {
|
||||
model string
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
model: "llava:7b",
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
{
|
||||
model: "llama3.2-vision",
|
||||
},
|
||||
{
|
||||
model: "gemma3",
|
||||
},
|
||||
}
|
||||
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
}
|
||||
for _, v := range testCases {
|
||||
t.Run(v.model, func(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: v.model,
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
|
||||
func TestIntegrationMllama(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
// TODO fix up once we publish the final image
|
||||
Model: "x/llama3.2-vision",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// llava models on CPU can be quite slow to start
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
resp := "the ollamas"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// mllama models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
}
|
||||
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
@ -75,7 +68,7 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
||||
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
|
@ -17,30 +17,30 @@ var (
|
||||
stream = false
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight"},
|
||||
{"sunlight", "scattering", "interact"},
|
||||
{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
||||
func TestIntegrationSimple(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||
|
@ -30,9 +30,9 @@ func TestMaxQueue(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
|
||||
embedCtx := ctx
|
||||
|
||||
var genwg sync.WaitGroup
|
||||
genwg.Add(1)
|
||||
go func() {
|
||||
genwg.Add(1)
|
||||
defer genwg.Done()
|
||||
slog.Info("Starting generate request")
|
||||
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
|
||||
@ -61,7 +61,7 @@ func TestMaxQueue(t *testing.T) {
|
||||
}()
|
||||
|
||||
// Give the generate a chance to get started before we start hammering on embed requests
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
threadCount += 10 // Add a few extra to ensure we push the queue past its limit
|
||||
busyCount := 0
|
||||
@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
|
||||
counterMu := sync.Mutex{}
|
||||
var embedwg sync.WaitGroup
|
||||
for i := 0; i < threadCount; i++ {
|
||||
embedwg.Add(1)
|
||||
go func(i int) {
|
||||
embedwg.Add(1)
|
||||
defer embedwg.Done()
|
||||
slog.Info("embed started", "id", i)
|
||||
embedReq := api.EmbeddingRequest{
|
||||
|
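Note on the TestMaxQueue hunk above: it moves genwg.Add(1) and embedwg.Add(1) out of the goroutines they track, because calling Add inside a goroutine races with the later Wait, which can return before the worker is counted. A minimal standalone sketch of the corrected pattern (illustration only, not the test's actual code):

package main

import "sync"

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1) // count the worker before it starts, never inside the goroutine
		go func(id int) {
			defer wg.Done()
			_ = id // real work would go here
		}(i)
	}
	wg.Wait() // now guaranteed to observe every Add
}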
184
integration/model_arch_test.go
Normal file
@ -0,0 +1,184 @@
|
||||
//go:build integration && models
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var (
|
||||
started = time.Now()
|
||||
chatModels = []string{
|
||||
"granite3-moe:latest",
|
||||
"granite-code:latest",
|
||||
"nemotron-mini:latest",
|
||||
"command-r:latest",
|
||||
"gemma2:latest",
|
||||
"gemma:latest",
|
||||
"internlm2:latest",
|
||||
"phi3.5:latest",
|
||||
"phi3:latest",
|
||||
// "phi:latest", // flaky, sometimes generates no response on first query
|
||||
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
|
||||
"falcon:latest",
|
||||
"falcon2:latest",
|
||||
"minicpm-v:latest",
|
||||
"mistral:latest",
|
||||
"orca-mini:latest",
|
||||
"llama2:latest",
|
||||
"llama3.1:latest",
|
||||
"llama3.2:latest",
|
||||
"llama3.2-vision:latest",
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen:latest",
|
||||
"solar-pro:latest",
|
||||
}
|
||||
)
|
||||
|
||||
func TestModelsGenerate(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// TODO use info API eventually
|
||||
var maxVram uint64
|
||||
var err error
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err = strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
||||
}
|
||||
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
if maxVram > 0 {
|
||||
resp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list models failed %v", err)
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
|
||||
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO - fiddle with context size
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
|
||||
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
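The skip logic in TestModelsGenerate compares the model's reported size, padded by 20%, against the OLLAMA_MAX_VRAM budget. The same check in isolation, with illustrative numbers (the helper name is made up and not part of the test file):

// tooLargeForVRAM mirrors the test's heuristic: skip when 1.2x the model size
// would not fit in the advertised VRAM.
func tooLargeForVRAM(modelSize int64, maxVram uint64) bool {
	return float32(modelSize)*1.2 > float32(maxVram)
}

// Example: a 6 GiB model against 7 GiB of VRAM needs about 7.2 GiB with headroom, so it is skipped.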
|
||||
|
||||
func TestModelsEmbed(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// TODO use info API eventually
|
||||
var maxVram uint64
|
||||
var err error
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err = strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(filepath.Join("testdata", "embed.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open test data file: %s", err)
|
||||
}
|
||||
testCase := map[string][]float64{}
|
||||
err = json.Unmarshal(data, &testCase)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load test data: %s", err)
|
||||
}
|
||||
for model, expected := range testCase {
|
||||
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
if maxVram > 0 {
|
||||
resp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list models failed %v", err)
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
|
||||
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
|
||||
}
|
||||
}
|
||||
}
|
||||
req := api.EmbeddingRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
resp, err := client.Embeddings(ctx, &req)
|
||||
if err != nil {
|
||||
t.Fatalf("embeddings call failed %s", err)
|
||||
}
|
||||
if len(resp.Embedding) == 0 {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
if len(expected) != len(resp.Embedding) {
|
||||
expStr := make([]string, len(resp.Embedding))
|
||||
for i, v := range resp.Embedding {
|
||||
expStr[i] = fmt.Sprintf("%0.6f", v)
|
||||
}
|
||||
// When adding new models, use this output to populate the testdata/embed.json
|
||||
fmt.Printf("expected\n%s\n", strings.Join(expStr, ", "))
|
||||
t.Fatalf("expected %d, got %d", len(expected), len(resp.Embedding))
|
||||
}
|
||||
sim := cosineSimilarity(resp.Embedding, expected)
|
||||
if sim < 0.99 {
|
||||
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], resp.Embedding[0:5], sim)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
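The embedding comparison in TestModelsEmbed calls a cosineSimilarity helper that is defined elsewhere in the integration package and does not appear in this diff. For reference, a typical implementation looks roughly like the sketch below (an assumption, not the repository's actual helper; it needs the standard "math" import):

// cosineSimilaritySketch returns the cosine of the angle between two equal-length vectors;
// values close to 1.0 mean the embeddings point in nearly the same direction.
func cosineSimilaritySketch(a, b []float64) float64 {
	var dot, normA, normB float64
	for i := range a {
		dot += a[i] * b[i]
		normA += a[i] * a[i]
		normB += b[i] * b[i]
	}
	return dot / (math.Sqrt(normA) * math.Sqrt(normB))
}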
|
130
integration/quantization_test.go
Normal file
@ -0,0 +1,130 @@
|
||||
//go:build integration && models
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQuantization(t *testing.T) {
|
||||
sourceModels := []string{
|
||||
"qwen2.5:0.5b-instruct-fp16",
|
||||
}
|
||||
quantizations := []string{
|
||||
"Q8_0",
|
||||
"Q4_K_S",
|
||||
"Q4_K_M",
|
||||
"Q4_K",
|
||||
}
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
started := time.Now()
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, base := range sourceModels {
|
||||
if err := PullIfMissing(ctx, client, base); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
for _, quant := range quantizations {
|
||||
newName := fmt.Sprintf("%s__%s", base, quant)
|
||||
t.Run(newName, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
req := &api.CreateRequest{
|
||||
Model: newName,
|
||||
Quantization: quant,
|
||||
From: base,
|
||||
}
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
// fmt.Print(".")
|
||||
return nil
|
||||
}
|
||||
t.Logf("quantizing: %s -> %s", base, quant)
|
||||
if err := client.Create(ctx, req, fn); err != nil {
|
||||
t.Fatalf("create failed %s", err)
|
||||
}
|
||||
defer func() {
|
||||
req := &api.DeleteRequest{
|
||||
Model: newName,
|
||||
}
|
||||
t.Logf("deleting: %s -> %s", base, quant)
|
||||
if err := client.Delete(ctx, req); err != nil {
|
||||
t.Logf("failed to clean up %s: %s", req.Model, err)
|
||||
}
|
||||
}()
|
||||
// Check metadata on the model
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: newName})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to show model: %s", err)
|
||||
}
|
||||
if !strings.Contains(resp.Details.QuantizationLevel, quant) {
|
||||
t.Fatalf("unexpected quantization for %s:\ngot: %s", newName, resp.Details.QuantizationLevel)
|
||||
}
|
||||
|
||||
stream := true
|
||||
genReq := api.GenerateRequest{
|
||||
Model: newName,
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 3 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Stream: &stream,
|
||||
}
|
||||
t.Logf("verifying: %s -> %s", base, quant)
|
||||
|
||||
// Some smaller quantizations can cause models to have poor quality
|
||||
// or get stuck in repetition loops, so we stop as soon as we have any matches
|
||||
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
|
||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||
atLeastOne := false
|
||||
var buf bytes.Buffer
|
||||
genfn := func(response api.GenerateResponse) error {
|
||||
buf.Write([]byte(response.Response))
|
||||
fullResp := strings.ToLower(buf.String())
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(fullResp, resp) {
|
||||
atLeastOne = true
|
||||
t.Log(fullResp)
|
||||
reqCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Generate(reqCtx, &genReq, genfn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if genErr != nil && !atLeastOne {
|
||||
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
|
||||
t.Logf("passed")
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
}
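The verification step in TestQuantization streams the response and cancels the request context as soon as any expected keyword shows up, so a quantized model that rambles or loops does not burn the whole timeout. The same pattern reduced to a helper (illustrative only; it reuses the api client plus the context and strings imports already present in the test file):

func generateUntilMatch(ctx context.Context, client *api.Client, model string, keywords []string) (bool, error) {
	reqCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	var buf strings.Builder
	matched := false
	err := client.Generate(reqCtx, &api.GenerateRequest{Model: model, Prompt: "why is the sky blue?"},
		func(r api.GenerateResponse) error {
			buf.WriteString(r.Response)
			full := strings.ToLower(buf.String())
			for _, k := range keywords {
				if strings.Contains(full, k) {
					matched = true
					cancel() // stop streaming once the output is good enough
					break
				}
			}
			return nil
		})
	if matched {
		return true, nil // a context.Canceled error after a match is expected, not a failure
	}
	return false, err
}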
|
21
integration/testdata/embed.json
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -24,9 +24,14 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
smol = "llama3.2:1b"
|
||||
)
|
||||
|
||||
func Init() {
|
||||
lifecycle.InitLogging()
|
||||
}
|
||||
@ -140,7 +145,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
|
||||
|
||||
showCtx, cancel := context.WithDeadlineCause(
|
||||
ctx,
|
||||
time.Now().Add(10*time.Second),
|
||||
time.Now().Add(20*time.Second),
|
||||
fmt.Errorf("show for existing model %s took too long", modelName),
|
||||
)
|
||||
defer cancel()
|
||||
@ -157,7 +162,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
|
||||
}
|
||||
slog.Info("model missing", "model", modelName)
|
||||
|
||||
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
|
||||
stallDuration := 60 * time.Second // This includes checksum verification, which can take a while on larger models, and slower systems
|
||||
stallTimer := time.NewTimer(stallDuration)
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
// fmt.Print(".")
|
||||
@ -212,6 +217,7 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
defer fp.Close()
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
@ -283,51 +289,51 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
}
|
||||
|
||||
// Generate a set of requests
|
||||
// By default each request uses orca-mini as the model
|
||||
// By default each request uses llama3.2 as the model
|
||||
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
return []api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Model: smol,
|
||||
Prompt: "what is the composition of air?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
@ -341,3 +347,26 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
}
|
||||
|
||||
func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err := strconv.ParseUint(s, 10, 64)
require.NoError(t, err)
// Don't hammer on small VRAM cards...
if maxVram < gb*format.GibiByte {
t.Skip("skipping with small VRAM to avoid timeouts")
}
}
}

func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline()
if !hasDeadline {
return 8 * time.Minute, 10 * time.Minute
} else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
t.Skip("too little time")
return time.Duration(0), time.Duration(0)
}
return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
}
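getTimeouts turns go test's -timeout deadline into two budgets: a soft one (stop starting new subtests two minutes before the deadline) and a hard one (cancel outstanding work twenty seconds before it), falling back to 8 and 10 minutes when no deadline is set. A typical call site, mirroring how the new model tests use it (the subtest names are placeholders):

func TestWithBudgets(t *testing.T) {
	softTimeout, hardTimeout := getTimeouts(t)
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()

	start := time.Now()
	for _, name := range []string{"case-a", "case-b"} {
		t.Run(name, func(t *testing.T) {
			if time.Since(start) > softTimeout {
				t.Skip("soft timeout reached; skipping remaining subtests")
			}
			_ = ctx // run the case against ctx here
		})
	}
}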
|
||||
|
@ -43,20 +43,31 @@ type Cache interface {
|
||||
|
||||
// ** cache management **
|
||||
|
||||
// Init sets up runtime parameters
|
||||
Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
||||
// Init sets up runtime parameters.
|
||||
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||
// dtype: The data type for storing cache entries
|
||||
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||
// capacity: The number of cache entries to store, per sequence
|
||||
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||
|
||||
// Close closes the cache and frees resources associated with it
|
||||
Close()
|
||||
|
||||
// StartForward is called before the start of the model's forward pass.
|
||||
// For each token in the coming batch, there must be a corresponding
|
||||
// entry in positions and seqs.
|
||||
StartForward(ctx ml.Context, opts input.Options) error
|
||||
// entry in positions and seqs. reserve is to preallocate memory
|
||||
// without actually storing data in the cache.
|
||||
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
|
||||
|
||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||
|
||||
// CanResume returns true if the cache can continue with the next token at
|
||||
// the given position and sequence. Assumes that the caller has already
|
||||
// verified the contents of the cache.
|
||||
CanResume(seq int, pos int32) bool
|
||||
|
||||
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||
//
|
||||
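The Cache interface's Init grows from a single capacity to (maxSequences, capacity, maxBatch), so implementations can size their storage from the runner's real limits instead of one number. A hedged sketch of a call site, written as if it lived inside the kvcache package (constants are illustrative, not taken from the runner code):

func initCacheSketch(backend ml.Backend, shift shiftFn) *Causal {
	const (
		numParallel = 2    // concurrent sequences
		numCtx      = 4096 // cache entries kept per sequence
		numBatch    = 512  // maximum tokens in a single forward pass
	)
	c := NewCausalCache(shift)
	c.Init(backend, ml.DTypeF16, numParallel, numCtx, numBatch)
	return c
}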
|
@ -20,8 +20,8 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
||||
// The mask is of shape history size, batch size
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
Capacity int32
|
||||
windowSize int32
|
||||
chunkSize int32
|
||||
|
||||
opts CausalOptions
|
||||
|
||||
@ -98,7 +98,18 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
windowSize: math.MaxInt32,
|
||||
chunkSize: chunkSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
if c.config == nil {
|
||||
var config ml.CacheConfig
|
||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
@ -119,9 +130,16 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
|
||||
var cacheSize int
|
||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
||||
cacheSize = maxSequences * capacity
|
||||
} else {
|
||||
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
||||
}
|
||||
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
c.cells = make([]cacheCell, cacheSize)
|
||||
|
||||
c.DType = dtype
|
||||
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
|
||||
c.cells = make([]cacheCell, c.Capacity)
|
||||
c.cellRanges = make(map[int]cellRange)
|
||||
c.backend = backend
|
||||
}
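The new sizing logic allocates maxSequences*capacity cells for full attention, but only maxSequences*windowSize + maxBatch cells when the sliding window is smaller than the per-sequence capacity, since anything older than the window is evicted anyway. A small worked example of the two formulas (numbers are made up):

func cacheSizeExample() (full, swa int) {
	const (
		maxSequences = 4
		capacity     = 8192 // per-sequence context length
		maxBatch     = 512
		windowSize   = 1024
	)
	full = maxSequences * capacity           // 32768 cells for full attention
	swa = maxSequences*windowSize + maxBatch // 4608 cells with a sliding window
	// Both values are then rounded up to the backend's CachePadding before allocation.
	return full, swa
}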
|
||||
@ -140,49 +158,60 @@ func (c *Causal) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
||||
c.curBatchSize = len(opts.Positions)
|
||||
c.curSequences = opts.Sequences
|
||||
c.curPositions = opts.Positions
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
c.curBatchSize = len(batch.Positions)
|
||||
c.curSequences = batch.Sequences
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
var err error
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
if errors.Is(err, ErrKvCacheFull) {
|
||||
c.defrag()
|
||||
if !reserve {
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range opts.Positions {
|
||||
seq := opts.Sequences[i]
|
||||
|
||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
if c.curLoc+i > seqRange.max {
|
||||
seqRange.max = c.curLoc + i
|
||||
}
|
||||
if seqRange.max > c.curCellRange.max {
|
||||
c.curCellRange.max = seqRange.max
|
||||
}
|
||||
|
||||
if c.curLoc+i < seqRange.min {
|
||||
seqRange.min = c.curLoc + i
|
||||
}
|
||||
if seqRange.min < c.curCellRange.min {
|
||||
c.curCellRange.min = seqRange.min
|
||||
}
|
||||
c.cellRanges[seq] = seqRange
|
||||
if errors.Is(err, ErrKvCacheFull) {
|
||||
c.defrag()
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range batch.Positions {
|
||||
seq := batch.Sequences[i]
|
||||
|
||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
if c.curLoc+i > seqRange.max {
|
||||
seqRange.max = c.curLoc + i
|
||||
}
|
||||
if seqRange.max > c.curCellRange.max {
|
||||
c.curCellRange.max = seqRange.max
|
||||
}
|
||||
|
||||
if c.curLoc+i < seqRange.min {
|
||||
seqRange.min = c.curLoc + i
|
||||
}
|
||||
if seqRange.min < c.curCellRange.min {
|
||||
c.curCellRange.min = seqRange.min
|
||||
}
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
} else {
|
||||
// If we are reserving memory, don't update any of the cache metadata but set the size
|
||||
// to the worst case.
|
||||
c.curLoc = 0
|
||||
c.curCellRange.min = 0
|
||||
c.curCellRange.max = len(c.cells) - 1
|
||||
}
|
||||
|
||||
var err error
|
||||
c.curMask, err = c.buildMask(ctx)
|
||||
|
||||
return err
|
||||
@ -210,7 +239,51 @@ func (c *Causal) findStartLoc() (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
||||
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
|
||||
}
|
||||
|
||||
func (c *Causal) updateSlidingWindow() {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
return
|
||||
}
|
||||
|
||||
// create a map of unique sequences to the lowest position in that sequence
|
||||
lowestPos := make(map[int]int32)
|
||||
for i := range c.curPositions {
|
||||
seq := c.curSequences[i]
|
||||
|
||||
pos, ok := lowestPos[seq]
|
||||
if !ok {
|
||||
pos = c.curPositions[i]
|
||||
} else if c.curPositions[i] < pos {
|
||||
pos = c.curPositions[i]
|
||||
}
|
||||
|
||||
lowestPos[seq] = pos
|
||||
}
|
||||
|
||||
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||
for seq, pos := range lowestPos {
|
||||
oldRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
newRange := newRange()
|
||||
|
||||
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
if c.cells[i].pos < pos-c.windowSize {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
} else {
|
||||
newRange.min = min(newRange.min, i)
|
||||
newRange.max = max(newRange.max, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.cellRanges[seq] = newRange
|
||||
}
|
||||
}
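updateSlidingWindow drops a cell from a sequence once its position falls more than windowSize behind the lowest position of that sequence in the incoming batch. The eviction test in isolation, with a worked example (a sketch, not the method itself):

// slidOutOfWindow reports whether a cached token at cellPos is evicted when the
// oldest incoming position for its sequence is lowestBatchPos.
func slidOutOfWindow(cellPos, lowestBatchPos, windowSize int32) bool {
	return cellPos < lowestBatchPos-windowSize
}

// Example: windowSize = 4 and batch positions {9, 10} give lowestBatchPos = 9,
// so cells at positions 0..4 are evicted while 5..8 remain usable.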
|
||||
|
||||
func roundDown(length, pad int) int {
|
||||
@ -239,6 +312,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
}
|
||||
@ -265,7 +339,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
return maskTensor, nil
|
||||
}
|
||||
|
||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
continue
|
||||
@ -275,8 +349,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
|
||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
|
||||
|
||||
value := c.values[i]
|
||||
var vSrcView, vDstView ml.Tensor
|
||||
@ -284,14 +358,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
|
||||
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
|
||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
|
||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
|
||||
}
|
||||
|
||||
ctx.Forward(
|
||||
@ -321,7 +395,8 @@ func (c *Causal) defrag() {
|
||||
ctx := c.backend.NewContext()
|
||||
|
||||
// For every move, 6 tensors are required per layer (2 views and a
|
||||
// copy for each of k and v).
|
||||
// copy for each of k and v). We also need to refer to the original
|
||||
// k and v cache tensors - once per layer, not per move.
|
||||
layers := 0
|
||||
for _, key := range c.keys {
|
||||
if key == nil {
|
||||
@ -330,7 +405,7 @@ func (c *Causal) defrag() {
|
||||
layers++
|
||||
}
|
||||
|
||||
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
|
||||
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
|
||||
moves := 0
|
||||
|
||||
var pendingSrc, pendingDst, pendingLen int
|
||||
@ -479,14 +554,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
|
||||
}
|
||||
|
||||
if _, ok := c.values[c.curLayer]; !ok {
|
||||
if c.config.PermutedV {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
|
||||
} else {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
|
||||
}
|
||||
}
|
||||
|
||||
@ -497,7 +572,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
elemSize := c.values[c.curLayer].Stride(0)
|
||||
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
|
||||
} else {
|
||||
rowSize := c.values[c.curLayer].Stride(2)
|
||||
|
||||
@ -528,6 +603,35 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
c.cellRanges[dstSeq] = seqRange
|
||||
}
|
||||
|
||||
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
return true
|
||||
}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// for sliding window, check that the window of the new sequence is contained in
|
||||
// the window of what we are storing
|
||||
var last int32 = -1
|
||||
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
last = max(last, c.cells[i].pos)
|
||||
}
|
||||
}
|
||||
|
||||
if last == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
lastWindowStart := max(0, last-c.windowSize)
|
||||
posWindowStart := max(0, pos-c.windowSize)
|
||||
|
||||
return posWindowStart >= lastWindowStart
|
||||
}
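CanResume allows reuse only when the window implied by the requested position starts no earlier than the window of the newest cached token, i.e. the cache still holds everything the resumed window needs. The comparison in isolation (illustrative helper; the builtin max works because both operands are int32):

func canResumeSketch(last, pos, windowSize int32) bool {
	lastWindowStart := max(int32(0), last-windowSize)
	posWindowStart := max(int32(0), pos-windowSize)
	return posWindowStart >= lastWindowStart
}

// With windowSize = 4 and newest cached position last = 5 (as in TestCanResume later in
// this diff): canResumeSketch(5, 5, 4) is true while canResumeSketch(5, 4, 4) is false.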
|
||||
|
||||
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
if c.shiftFn == nil {
|
||||
return ErrNotSupported
|
||||
@ -582,6 +686,12 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
}
|
||||
|
||||
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// TODO(jessegross): We should check to see if removing the middle of the sequence will
|
||||
// cause the sliding window to encompass tokens that we no longer have. If so, then we
|
||||
// should return an error, which will trigger the runner to evaluate the full history and
|
||||
// rebuild the window. However, if we have multimodal inputs in our history, this reuse
|
||||
// results in use after free, so we don't do it for now.
|
||||
|
||||
var offset int32
|
||||
if endIndex != math.MaxInt32 {
|
||||
offset = beginIndex - endIndex
|
||||
@ -596,8 +706,7 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
} else {
|
||||
if c.cells[i].pos >= endIndex {
|
||||
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
|
||||
// TODO(jessegross): Need to be careful about data shared between sequences
|
||||
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
|
||||
return errors.New("shifting cells shared by multiple sequences not supported")
|
||||
}
|
||||
|
||||
c.cells[i].pos += offset
|
||||
|
@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
|
||||
cache := NewCausalCache(nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
|
||||
cache := NewSWACache(1, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF32, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "SlidingWindow",
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
@ -71,17 +71,85 @@ func TestSWA(t *testing.T) {
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{5, 6, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestChunkedAttention(t *testing.T) {
|
||||
cache := NewChunkedAttentionCache(2, nil)
|
||||
defer cache.Close()
|
||||
|
||||
var b testBackend
|
||||
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
x := float32(math.Inf(-1))
|
||||
|
||||
testCache(
|
||||
t, &b, cache,
|
||||
[]testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{
|
||||
0, x, x, x,
|
||||
0, 0, x, x,
|
||||
x, x, 0, x,
|
||||
x, x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{5, 6, 7},
|
||||
inShape: []int{1, 1, 3},
|
||||
seqs: []int{0, 0, 0},
|
||||
pos: []int32{4, 5, 6},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6, 7},
|
||||
expectedShape: []int{1, 1, 7},
|
||||
expectedMask: []float32{
|
||||
x, x, x, x, 0, x, x,
|
||||
x, x, x, x, 0, 0, x,
|
||||
x, x, x, x, x, x, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ThirdBatch",
|
||||
in: []float32{8, 9},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{7, 8},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||
expectedShape: []int{1, 1, 9},
|
||||
expectedMask: []float32{
|
||||
x, x, x, x, x, x, 0, 0, x,
|
||||
x, x, x, x, x, x, x, x, 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
}
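With chunked attention a query at position p may only attend to cached positions inside its own chunk, i.e. positions from p - p%chunkSize up to p itself; everything earlier or in the future is masked. The masks asserted in TestChunkedAttention follow from that rule. The per-cell check in isolation (a sketch of the condition used in buildMask, not the production code):

// chunkMasked reports whether the cached token at cellPos is masked out for a query
// at queryPos when attention is restricted to chunks of chunkSize tokens.
func chunkMasked(cellPos, queryPos, chunkSize int32) bool {
	chunkStart := queryPos - queryPos%chunkSize
	return cellPos < chunkStart || cellPos > queryPos // outside the chunk, or in the future
}

// Example with chunkSize = 2: position 6 starts a new chunk (6 - 6%2 = 6), so a query at 6
// masks every earlier position, which matches the last row of the SecondBatch mask above.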
|
||||
|
||||
func TestSequences(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewCausalCache(nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@ -116,7 +184,7 @@ func TestRemove(t *testing.T) {
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@ -181,7 +249,7 @@ func TestDefrag(t *testing.T) {
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@ -229,7 +297,7 @@ func TestCopy(t *testing.T) {
|
||||
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
@ -270,7 +338,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
|
||||
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -283,21 +351,94 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
|
||||
context.Forward(out, mask).Compute(out, mask)
|
||||
|
||||
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
|
||||
if !slices.Equal(out.Floats(), test.expected) {
|
||||
t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
|
||||
}
|
||||
|
||||
if !slices.Equal(out.Shape(), test.expectedShape) {
|
||||
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
|
||||
}
|
||||
|
||||
if !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testBackend struct{}
|
||||
func TestCanResume(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
windowSize := int32(4)
|
||||
cache := NewSWACache(windowSize, nil)
|
||||
defer cache.Close()
|
||||
|
||||
func (b *testBackend) Config() ml.Config {
|
||||
panic("not implemented")
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3},
|
||||
Sequences: []int{0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// with window size 4, nothing has slid out of the window yet
|
||||
if !cache.CanResume(0, 0) {
|
||||
t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 1) {
|
||||
t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 2) {
|
||||
t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 3) {
|
||||
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
}
|
||||
|
||||
// shift window by adding position 4
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{4, 5},
|
||||
Sequences: []int{0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
if cache.CanResume(0, 0) {
|
||||
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 1) {
|
||||
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 2) {
|
||||
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 3) {
|
||||
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 4) {
|
||||
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
}
|
||||
if !cache.CanResume(0, 5) {
|
||||
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||
}
|
||||
}
|
||||
|
||||
func (b *testBackend) Get(name string) ml.Tensor {
|
||||
panic("not implemented")
|
||||
type testBackend struct {
|
||||
ml.Backend
|
||||
}
|
||||
|
||||
func (b *testBackend) NewContext() ml.Context {
|
||||
@ -308,12 +449,10 @@ func (b *testBackend) NewContextSize(int) ml.Context {
|
||||
return &testContext{}
|
||||
}
|
||||
|
||||
func (b *testBackend) SystemInfo() string {
|
||||
return "not implemented"
|
||||
type testContext struct {
|
||||
ml.Context
|
||||
}
|
||||
|
||||
type testContext struct{}
|
||||
|
||||
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
total := 0
|
||||
|
||||
@ -351,14 +490,26 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
s := make([]float32, 0, int((stop-start)/step))
|
||||
for i := start; i < stop; i += step {
|
||||
s = append(s, i)
|
||||
}
|
||||
|
||||
out, _ := c.FromFloatSlice(s, len(s))
|
||||
out.(*testTensor).dtype = dtype
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *testContext) Input() ml.Context { return c }
|
||||
func (c *testContext) Output() ml.Context { return c }
|
||||
func (c *testContext) Layer(int) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Compute(...ml.Tensor) {}
|
||||
|
||||
func (c *testContext) Reserve() error { return nil }
|
||||
|
||||
func (c *testContext) MaxGraphNodes() int {
|
||||
return 10
|
||||
}
|
||||
@ -366,6 +517,8 @@ func (c *testContext) MaxGraphNodes() int {
|
||||
func (c *testContext) Close() {}
|
||||
|
||||
type testTensor struct {
|
||||
ml.Tensor
|
||||
|
||||
dtype ml.DType
|
||||
elementSize int
|
||||
data []float32
|
||||
@ -393,16 +546,20 @@ func (t *testTensor) DType() ml.DType {
|
||||
return t.dtype
|
||||
}
|
||||
|
||||
func (t *testTensor) Bytes() []byte {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Floats() []float32 {
|
||||
out := make([]float32, len(t.data))
|
||||
copy(out, t.data)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
for i := range out.data {
|
||||
out.data[i] = -t.data[i]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
|
||||
@ -413,66 +570,6 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
offset /= t.elementSize
|
||||
|
||||
@ -495,38 +592,6 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
return view
|
||||
}
|
||||
|
||||
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
copy(t2.(*testTensor).data, t.data)
|
||||
return nil
|
||||
|
@ -27,6 +27,11 @@ type EncoderCache struct {
|
||||
// anything will be stored)
|
||||
curPos int32
|
||||
|
||||
// curReserve indicates that this forward pass is only for
|
||||
// memory reservation and we should not update our metadata
|
||||
// based on it.
|
||||
curReserve bool
|
||||
|
||||
// ** cache metadata **
|
||||
|
||||
// was something stored in the cache?
|
||||
@ -49,7 +54,7 @@ func NewEncoderCache() *EncoderCache {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
if c.config == nil {
|
||||
var config ml.CacheConfig
|
||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
@ -58,6 +63,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
if maxSequences > 1 {
|
||||
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
|
||||
}
|
||||
|
||||
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||
}
|
||||
@ -79,12 +88,14 @@ func (c *EncoderCache) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
|
||||
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
// We work with the most recent image
|
||||
if len(opts.Multimodal) > 0 {
|
||||
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
|
||||
if len(batch.Multimodal) > 0 {
|
||||
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||
}
|
||||
|
||||
c.curReserve = reserve
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -101,8 +112,10 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.encoderPos = c.curPos
|
||||
c.encoderCached = true
|
||||
if !c.curReserve {
|
||||
c.encoderPos = c.curPos
|
||||
c.encoderCached = true
|
||||
}
|
||||
|
||||
if c.config.PermutedV {
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
@ -130,6 +143,10 @@ func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
panic("encoder cache does not support multiple sequences")
|
||||
}
|
||||
|
||||
func (c *EncoderCache) CanResume(seq int, pos int32) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
||||
c.encoderCached = false
|
||||
|
@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
for _, cache := range c.caches {
|
||||
cache.Init(backend, dtype, capacity)
|
||||
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
}
|
||||
|
||||
@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
|
||||
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
for i, cache := range c.caches {
|
||||
err := cache.StartForward(ctx, opts)
|
||||
err := cache.StartForward(ctx, batch, reserve)
|
||||
if err != nil {
|
||||
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
for k := range opts.Positions {
|
||||
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
|
||||
for k := range batch.Positions {
|
||||
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||
}
|
||||
}
|
||||
return err
|
||||
@ -87,6 +87,16 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) CanResume(seq int, pos int32) bool {
|
||||
for _, cache := range c.caches {
|
||||
if !cache.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
// If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
|
||||
for _, cache := range c.caches {
|
||||
|
2
llama/build-info.cpp
generated
vendored
@ -1,4 +1,4 @@
|
||||
int LLAMA_BUILD_NUMBER = 0;
|
||||
char const *LLAMA_COMMIT = "d7cfe1ffe0f435d0048a6058d529daf76e072d9c";
|
||||
char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618";
|
||||
char const *LLAMA_COMPILER = "";
|
||||
char const *LLAMA_BUILD_TARGET = "";
|
||||
|
@ -10,10 +10,11 @@ include common/stb_image.*
|
||||
include include/
|
||||
include include/llama.*
|
||||
include include/llama-*.*
|
||||
include examples/
|
||||
include examples/llava/
|
||||
include examples/llava/clip.*
|
||||
include examples/llava/llava.*
|
||||
include tools/
|
||||
include tools/mtmd/
|
||||
include tools/mtmd/clip.*
|
||||
include tools/mtmd/clip-impl.*
|
||||
include tools/mtmd/llava.*
|
||||
include src/
|
||||
include src/llama.*
|
||||
include src/llama-*.*
|
||||
|
562
llama/llama.cpp/common/common.cpp
vendored
@ -7,10 +7,6 @@
|
||||
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
@ -52,47 +48,11 @@
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#if defined(LLAMA_USE_CURL)
|
||||
#include <curl/curl.h>
|
||||
#include <curl/easy.h>
|
||||
#include <future>
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
#if defined(LLAMA_USE_CURL)
|
||||
#ifdef __linux__
|
||||
#include <linux/limits.h>
|
||||
#elif defined(_WIN32)
|
||||
# if !defined(PATH_MAX)
|
||||
# define PATH_MAX MAX_PATH
|
||||
# endif
|
||||
#else
|
||||
#include <sys/syslimits.h>
|
||||
#endif
|
||||
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
||||
|
||||
//
|
||||
// CURL utils
|
||||
//
|
||||
|
||||
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
|
||||
|
||||
// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
|
||||
struct curl_slist_ptr {
|
||||
struct curl_slist * ptr = nullptr;
|
||||
~curl_slist_ptr() {
|
||||
if (ptr) {
|
||||
curl_slist_free_all(ptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif // LLAMA_USE_CURL
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
//
|
||||
// CPU utils
|
||||
//
|
||||
@ -483,6 +443,11 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$0");
|
||||
}
|
||||
|
||||
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
|
||||
std::ostringstream result;
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
@ -865,7 +830,7 @@ std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#ifdef __linux__
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
@ -875,7 +840,9 @@ std::string fs_get_cache_directory() {
|
||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||
#elif defined(_WIN32)
|
||||
cache_directory = std::getenv("LOCALAPPDATA");
|
||||
#endif // __linux__
|
||||
#else
|
||||
# error Unknown architecture
|
||||
#endif
|
||||
cache_directory = ensure_trailing_slash(cache_directory);
|
||||
cache_directory += "llama.cpp";
|
||||
}
|
||||
@@ -896,22 +863,14 @@ std::string fs_get_cache_file(const std::string & filename) {
//
// Model utils
//

struct common_init_result common_init_from_params(common_params & params) {
    common_init_result iparams;
    auto mparams = common_model_params_to_llama(params);

    llama_model * model = nullptr;

    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
        model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
    } else if (!params.model_url.empty()) {
        model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
    } else {
        model = llama_model_load_from_file(params.model.c_str(), mparams);
    }

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
    if (model == NULL) {
        LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
        LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
        return iparams;
    }

@@ -946,13 +905,13 @@ struct common_init_result common_init_from_params(common_params & params) {

    llama_context * lctx = llama_init_from_model(model, cparams);
    if (lctx == NULL) {
        LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
        LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
        llama_model_free(model);
        return iparams;
    }

    if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
        LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
    if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
        LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
        params.ctx_shift = false;
    }

@@ -1029,6 +988,8 @@ struct common_init_result common_init_from_params(common_params & params) {
    if (params.warmup) {
        LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);

        llama_set_warmup(lctx, true);

        std::vector<llama_token> tmp;
        llama_token bos = llama_vocab_bos(vocab);
        llama_token eos = llama_vocab_eos(vocab);
@@ -1056,9 +1017,10 @@ struct common_init_result common_init_from_params(common_params & params) {
        if (llama_model_has_decoder(model)) {
            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
        }
        llama_kv_cache_clear(lctx);
        llama_kv_self_clear(lctx);
        llama_synchronize(lctx);
        llama_perf_context_reset(lctx);
        llama_set_warmup(lctx, false);
    }

    iparams.model.reset(model);
@@ -1067,6 +1029,19 @@ struct common_init_result common_init_from_params(common_params & params) {
    return iparams;
}

std::string get_model_endpoint() {
    const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
    // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
    const char * hf_endpoint_env = getenv("HF_ENDPOINT");
    const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env;
    std::string model_endpoint = "https://huggingface.co/";
    if (endpoint_env) {
        model_endpoint = endpoint_env;
        if (model_endpoint.back() != '/') model_endpoint += '/';
    }
    return model_endpoint;
}
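Expected precedence for get_model_endpoint() above, shown with hypothetical endpoint values (illustrative only):

    // neither variable set                    -> "https://huggingface.co/"
    // HF_ENDPOINT=https://hf-mirror.example   -> "https://hf-mirror.example/"  (trailing slash appended)
    // MODEL_ENDPOINT=http://cache.local:8080/ -> "http://cache.local:8080/"    (takes precedence over HF_ENDPOINT)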
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
    llama_clear_adapter_lora(ctx);
    for (auto & la : lora) {
@@ -1082,15 +1057,18 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
    if (!params.devices.empty()) {
        mparams.devices = params.devices.data();
    }

    if (params.n_gpu_layers != -1) {
        mparams.n_gpu_layers = params.n_gpu_layers;
    }

    mparams.main_gpu = params.main_gpu;
    mparams.split_mode = params.split_mode;
    mparams.tensor_split = params.tensor_split;
    mparams.use_mmap = params.use_mmap;
    mparams.use_mlock = params.use_mlock;
    mparams.check_tensors = params.check_tensors;

    if (params.kv_overrides.empty()) {
        mparams.kv_overrides = NULL;
    } else {
@@ -1098,6 +1076,13 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
        mparams.kv_overrides = params.kv_overrides.data();
    }

    if (params.tensor_buft_overrides.empty()) {
        mparams.tensor_buft_overrides = NULL;
    } else {
        GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
        mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
    }

    return mparams;
}

@@ -1111,7 +1096,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
    cparams.n_threads       = params.cpuparams.n_threads;
    cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
                              params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
    cparams.logits_all        = params.logits_all;
    cparams.embeddings        = params.embedding;
    cparams.rope_scaling_type = params.rope_scaling_type;
    cparams.rope_freq_base    = params.rope_freq_base;
@@ -1129,6 +1113,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
    cparams.offload_kqv = !params.no_kv_offload;
    cparams.flash_attn  = params.flash_attn;
    cparams.no_perf     = params.no_perf;
    cparams.op_offload  = !params.no_op_offload;

    if (params.reranking) {
        cparams.embeddings = true;
@@ -1157,451 +1142,6 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
    return tpp;
}

#ifdef LLAMA_USE_CURL

#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2

static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
    int remaining_attempts = max_attempts;

    while (remaining_attempts > 0) {
        LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);

        CURLcode res = curl_easy_perform(curl);
        if (res == CURLE_OK) {
            return true;
        }

        int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
        LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);

        remaining_attempts--;
        std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
    }

    LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);

    return false;
}
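Worked example of the retry timing above, with the defaults CURL_MAX_RETRY = 3 and CURL_RETRY_DELAY_SECONDS = 2 (the delay is retry_delay_seconds raised to the number of failures so far, in milliseconds):

    // failure 1: pow(2, 0) * 1000 = 1000 ms
    // failure 2: pow(2, 1) * 1000 = 2000 ms
    // failure 3: pow(2, 2) * 1000 = 4000 ms (slept even though no further attempt follows), then false is returned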
|
||||
|
||||
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
|
||||
// Initialize libcurl
|
||||
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
|
||||
curl_slist_ptr http_headers;
|
||||
if (!curl) {
|
||||
LOG_ERR("%s: error initializing libcurl\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool force_download = false;
|
||||
|
||||
// Set the URL, allow to follow http redirection
|
||||
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
|
||||
|
||||
// Check if hf-token or bearer-token was specified
|
||||
if (!hf_token.empty()) {
|
||||
std::string auth_header = "Authorization: Bearer " + hf_token;
|
||||
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
|
||||
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
|
||||
}
|
||||
|
||||
#if defined(_WIN32)
|
||||
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
|
||||
// operating system. Currently implemented under MS-Windows.
|
||||
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
|
||||
#endif
|
||||
|
||||
// Check if the file already exists locally
|
||||
auto file_exists = std::filesystem::exists(path);
|
||||
|
||||
// If the file exists, check its JSON metadata companion file.
|
||||
std::string metadata_path = path + ".json";
|
||||
nlohmann::json metadata;
|
||||
std::string etag;
|
||||
std::string last_modified;
|
||||
|
||||
if (file_exists) {
|
||||
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
|
||||
std::ifstream metadata_in(metadata_path);
|
||||
if (metadata_in.good()) {
|
||||
try {
|
||||
metadata_in >> metadata;
|
||||
LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
|
||||
if (metadata.contains("url") && metadata.at("url").is_string()) {
|
||||
auto previous_url = metadata.at("url").get<std::string>();
|
||||
if (previous_url != url) {
|
||||
LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
|
||||
etag = metadata.at("etag");
|
||||
}
|
||||
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
|
||||
last_modified = metadata.at("lastModified");
|
||||
}
|
||||
} catch (const nlohmann::json::exception & e) {
|
||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
// Send a HEAD request to retrieve the etag and last-modified headers
|
||||
struct common_load_model_from_url_headers {
|
||||
std::string etag;
|
||||
std::string last_modified;
|
||||
};
|
||||
|
||||
common_load_model_from_url_headers headers;
|
||||
|
||||
{
|
||||
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
|
||||
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
|
||||
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
|
||||
|
||||
static std::regex header_regex("([^:]+): (.*)\r\n");
|
||||
static std::regex etag_regex("ETag", std::regex_constants::icase);
|
||||
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
|
||||
|
||||
std::string header(buffer, n_items);
|
||||
std::smatch match;
|
||||
if (std::regex_match(header, match, header_regex)) {
|
||||
const std::string & key = match[1];
|
||||
const std::string & value = match[2];
|
||||
if (std::regex_match(key, match, etag_regex)) {
|
||||
headers->etag = value;
|
||||
} else if (std::regex_match(key, match, last_modified_regex)) {
|
||||
headers->last_modified = value;
|
||||
}
|
||||
}
|
||||
return n_items;
|
||||
};
|
||||
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
|
||||
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
|
||||
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
|
||||
|
||||
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
|
||||
if (!was_perform_successful) {
|
||||
return false;
|
||||
}
|
||||
|
||||
long http_code = 0;
|
||||
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
||||
if (http_code != 200) {
|
||||
// HEAD not supported, we don't know if the file has changed
|
||||
// force trigger downloading
|
||||
force_download = true;
|
||||
LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
|
||||
}
|
||||
}
|
||||
|
||||
bool should_download = !file_exists || force_download;
|
||||
if (!should_download) {
|
||||
if (!etag.empty() && etag != headers.etag) {
|
||||
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
|
||||
should_download = true;
|
||||
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
|
||||
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
|
||||
should_download = true;
|
||||
}
|
||||
}
|
||||
if (should_download) {
|
||||
std::string path_temporary = path + ".downloadInProgress";
|
||||
if (file_exists) {
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Set the output file
|
||||
|
||||
struct FILE_deleter {
|
||||
void operator()(FILE * f) const {
|
||||
fclose(f);
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
|
||||
if (!outfile) {
|
||||
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
|
||||
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
|
||||
return fwrite(data, size, nmemb, (FILE *)fd);
|
||||
};
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
|
||||
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
|
||||
|
||||
// display download progress
|
||||
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
|
||||
|
||||
// helper function to hide password in URL
|
||||
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
|
||||
std::size_t protocol_pos = url.find("://");
|
||||
if (protocol_pos == std::string::npos) {
|
||||
return url; // Malformed URL
|
||||
}
|
||||
|
||||
std::size_t at_pos = url.find('@', protocol_pos + 3);
|
||||
if (at_pos == std::string::npos) {
|
||||
return url; // No password in URL
|
||||
}
|
||||
|
||||
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
|
||||
};
|
||||
|
||||
// start the download
|
||||
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
|
||||
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
|
||||
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
|
||||
if (!was_perform_successful) {
|
||||
return false;
|
||||
}
|
||||
|
||||
long http_code = 0;
|
||||
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
||||
if (http_code < 200 || http_code >= 400) {
|
||||
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Causes file to be closed explicitly here before we rename it.
|
||||
outfile.reset();
|
||||
|
||||
// Write the updated JSON metadata file.
|
||||
metadata.update({
|
||||
{"url", url},
|
||||
{"etag", headers.etag},
|
||||
{"lastModified", headers.last_modified}
|
||||
});
|
||||
std::ofstream(metadata_path) << metadata.dump(4);
|
||||
LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
|
||||
|
||||
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
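For context, common_download_file() above keeps a JSON sidecar next to the downloaded file and only re-downloads when the server's ETag or Last-Modified header differs from the recorded one; the download is written to <local_path>.downloadInProgress and renamed into place on success. A sketch of the sidecar it writes (field values are illustrative):

    // <local_path>.json
    {
        "url": "https://huggingface.co/<repo>/resolve/main/model.gguf",
        "etag": "\"0123456789abcdef\"",
        "lastModified": "Tue, 01 Apr 2025 12:00:00 GMT"
    }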
struct llama_model * common_load_model_from_url(
        const std::string & model_url,
        const std::string & local_path,
        const std::string & hf_token,
        const struct llama_model_params & params) {
    // Basic validation of the model_url
    if (model_url.empty()) {
        LOG_ERR("%s: invalid model_url\n", __func__);
        return NULL;
    }

    if (!common_download_file(model_url, local_path, hf_token)) {
        return NULL;
    }

    // check for additional GGUFs split to download
    int n_split = 0;
    {
        struct gguf_init_params gguf_params = {
            /*.no_alloc = */ true,
            /*.ctx      = */ NULL,
        };
        auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
        if (!ctx_gguf) {
            LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str());
            return NULL;
        }

        auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
        if (key_n_split >= 0) {
            n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
        }

        gguf_free(ctx_gguf);
    }

    if (n_split > 1) {
        char split_prefix[PATH_MAX] = {0};
        char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};

        // Verify the first split file format
        // and extract split URL and PATH prefixes
        {
            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
                LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
                return NULL;
            }

            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
                LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
                return NULL;
            }
        }

        // Prepare download in parallel
        std::vector<std::future<bool>> futures_download;
        for (int idx = 1; idx < n_split; idx++) {
            futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
                char split_path[PATH_MAX] = {0};
                llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);

                char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
                llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);

                return common_download_file(split_url, split_path, hf_token);
            }, idx));
        }

        // Wait for all downloads to complete
        for (auto & f : futures_download) {
            if (!f.get()) {
                return NULL;
            }
        }
    }

    return llama_model_load_from_file(local_path.c_str(), params);
}
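Illustrative walk-through of the split handling above, assuming the usual llama.cpp shard naming (<prefix>-%05d-of-%05d.gguf) and n_split = 3; the paths are hypothetical:

    // local_path = models/foo-00001-of-00003.gguf       -> split_prefix     = models/foo
    // model_url  = https://host/foo-00001-of-00003.gguf -> split_url_prefix = https://host/foo
    // parallel downloads launched for idx = 1..2:
    //   https://host/foo-00002-of-00003.gguf -> models/foo-00002-of-00003.gguf
    //   https://host/foo-00003-of-00003.gguf -> models/foo-00003-of-00003.gguf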
struct llama_model * common_load_model_from_hf(
        const std::string & repo,
        const std::string & remote_path,
        const std::string & local_path,
        const std::string & hf_token,
        const struct llama_model_params & params) {
    // construct hugging face model url:
    //
    //  --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
    //    https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
    //
    //  --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
    //    https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
    //

    std::string model_url = "https://huggingface.co/";
    model_url += repo;
    model_url += "/resolve/main/";
    model_url += remote_path;

    return common_load_model_from_url(model_url, local_path, hf_token, params);
}
|
||||
|
||||
/**
 * Allow getting the HF file from the HF repo with tag (like ollama), for example:
 * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
 * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
 * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
 * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
 *
 * Return pair of <repo, file> (with "repo" already having tag removed)
 *
 * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
 */
std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
    auto parts = string_split<std::string>(hf_repo_with_tag, ':');
    std::string tag = parts.size() > 1 ? parts.back() : "latest";
    std::string hf_repo = parts[0];
    if (string_split<std::string>(hf_repo, '/').size() != 2) {
        throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
    }

    // fetch model info from Hugging Face Hub API
    json model_info;
    curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
    curl_slist_ptr http_headers;
    std::string res_str;
    std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
    curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
    typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
    auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
        static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
        return size * nmemb;
    };
    curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
#if defined(_WIN32)
    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
    if (!hf_token.empty()) {
        std::string auth_header = "Authorization: Bearer " + hf_token;
        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
    }
    // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
    http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
    http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);

    CURLcode res = curl_easy_perform(curl.get());

    if (res != CURLE_OK) {
        throw std::runtime_error("error: cannot make GET request to HF API");
    }

    long res_code;
    curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
    if (res_code == 200) {
        model_info = json::parse(res_str);
    } else if (res_code == 401) {
        throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
    } else {
        throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
    }

    // check response
    if (!model_info.contains("ggufFile")) {
        throw std::runtime_error("error: model does not have ggufFile");
    }
    json & gguf_file = model_info.at("ggufFile");
    if (!gguf_file.contains("rfilename")) {
        throw std::runtime_error("error: ggufFile does not have rfilename");
    }

    return std::make_pair(hf_repo, gguf_file.at("rfilename"));
}
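A usage sketch for common_get_hf_file() above (illustrative; the returned file name depends on what the manifest's ggufFile.rfilename reports):

    auto [repo, file] = common_get_hf_file("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", /*hf_token=*/"");
    // repo == "bartowski/Llama-3.2-3B-Instruct-GGUF"
    // file == e.g. "Llama-3.2-3B-Instruct-Q4_K_M.gguf" (hypothetical value taken from the manifest)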
|
||||
|
||||
#else

struct llama_model * common_load_model_from_url(
        const std::string & /*model_url*/,
        const std::string & /*local_path*/,
        const std::string & /*hf_token*/,
        const struct llama_model_params & /*params*/) {
    LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
    return nullptr;
}

struct llama_model * common_load_model_from_hf(
        const std::string & /*repo*/,
        const std::string & /*remote_path*/,
        const std::string & /*local_path*/,
        const std::string & /*hf_token*/,
        const struct llama_model_params & /*params*/) {
    LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
    return nullptr;
}

std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
    LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
    return std::make_pair("", "");
}

#endif // LLAMA_USE_CURL
|
||||
|
||||
//
// Batch utils
//
@@ -2026,3 +1566,19 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
    return result;
}

ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
    const int64_t ne_datapoint = llama_n_ctx(ctx);
    const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
    ggml_opt_dataset_t result = ggml_opt_dataset_init(
        GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);

    llama_token * data   = (llama_token *) ggml_opt_dataset_data(result)->data;
    llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;

    for (int64_t idata = 0; idata < ndata; ++idata) {
        memcpy(data   + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
        memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
    }

    return result;
}
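Worked example for common_opt_dataset_init() above: each datapoint is a context-sized window of tokens and its labels are the same window shifted right by one token. With hypothetical sizes n_ctx = 512, tokens.size() = 10753 and stride = 512:

    // ne_datapoint = llama_n_ctx(ctx) = 512
    // ndata        = (10753 - 512 - 1) / 512 = 20 sliding windows
    // data[i]      = tokens[i*512 ... i*512 + 511]
    // labels[i]    = tokens[i*512 + 1 ... i*512 + 512]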
95  llama/llama.cpp/common/common.h  vendored
@ -66,7 +66,6 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_COMMON,
|
||||
LLAMA_EXAMPLE_SPECULATIVE,
|
||||
LLAMA_EXAMPLE_MAIN,
|
||||
LLAMA_EXAMPLE_INFILL,
|
||||
LLAMA_EXAMPLE_EMBEDDING,
|
||||
LLAMA_EXAMPLE_PERPLEXITY,
|
||||
LLAMA_EXAMPLE_RETRIEVAL,
|
||||
@ -96,6 +95,7 @@ enum common_sampler_type {
|
||||
COMMON_SAMPLER_TYPE_XTC = 8,
|
||||
COMMON_SAMPLER_TYPE_INFILL = 9,
|
||||
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
||||
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
|
||||
};
|
||||
|
||||
// dimensionality reduction methods, used by cvector-generator
|
||||
@ -110,9 +110,17 @@ enum common_conversation_mode {
|
||||
COMMON_CONVERSATION_MODE_AUTO = 2,
|
||||
};
|
||||
|
||||
enum common_grammar_trigger_type {
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
};
|
||||
|
||||
struct common_grammar_trigger {
|
||||
std::string word;
|
||||
bool at_start;
|
||||
common_grammar_trigger_type type;
|
||||
std::string value;
|
||||
llama_token token = LLAMA_TOKEN_NULL;
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
@ -153,6 +161,7 @@ struct common_params_sampling {
|
||||
std::vector<enum common_sampler_type> samplers = {
|
||||
COMMON_SAMPLER_TYPE_PENALTIES,
|
||||
COMMON_SAMPLER_TYPE_DRY,
|
||||
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
|
||||
COMMON_SAMPLER_TYPE_TOP_K,
|
||||
COMMON_SAMPLER_TYPE_TYPICAL_P,
|
||||
COMMON_SAMPLER_TYPE_TOP_P,
|
||||
@ -163,8 +172,7 @@ struct common_params_sampling {
|
||||
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
bool grammar_lazy = false;
|
||||
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
|
||||
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
|
||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
@@ -173,6 +181,13 @@ struct common_params_sampling {
    std::string print() const;
};

struct common_params_model {
    std::string path    = ""; // model local path // NOLINT
    std::string url     = ""; // model url to download // NOLINT
    std::string hf_repo = ""; // HF repo // NOLINT
    std::string hf_file = ""; // HF file // NOLINT
};
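Illustrative only: common_params_model groups what used to be separate model / model_url / hf_repo / hf_file strings; typically exactly one source is filled in (the values below are hypothetical):

    common_params_model m;
    m.path    = "models/llama-3.2-3b-q4_k_m.gguf";       // local file, or
    m.url     = "https://example.com/llama-3.2-3b.gguf"; // direct URL, or
    m.hf_repo = "bartowski/Llama-3.2-3B-Instruct-GGUF";  // HF repo plus
    m.hf_file = "Llama-3.2-3B-Instruct-Q4_K_M.gguf";     // a file inside it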
|
||||
struct common_params_speculative {
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
@ -186,19 +201,13 @@ struct common_params_speculative {
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
|
||||
std::string model = ""; // draft model for speculative decoding // NOLINT
|
||||
std::string model_url = ""; // model url to download // NOLINT
|
||||
struct common_params_model model;
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
struct common_params_model model;
|
||||
|
||||
std::string model = ""; // model path // NOLINT
|
||||
std::string model_url = ""; // model url to download // NOLINT
|
||||
std::string speaker_file = ""; // speaker file path // NOLINT
|
||||
|
||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||
};
|
||||
@ -254,13 +263,12 @@ struct common_params {
|
||||
struct common_params_speculative speculative;
|
||||
struct common_params_vocoder vocoder;
|
||||
|
||||
std::string model = ""; // model path // NOLINT
|
||||
struct common_params_model model;
|
||||
|
||||
std::string model_alias = ""; // model alias // NOLINT
|
||||
std::string model_url = ""; // model url to download // NOLINT
|
||||
std::string hf_token = ""; // HF token // NOLINT
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
std::string prompt = ""; // NOLINT
|
||||
std::string system_prompt = ""; // NOLINT
|
||||
std::string prompt_file = ""; // store the external prompt file name // NOLINT
|
||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
||||
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
||||
@ -272,6 +280,7 @@ struct common_params {
|
||||
std::vector<std::string> in_files; // all input files
|
||||
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
|
||||
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
|
||||
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
|
||||
@ -315,7 +324,6 @@ struct common_params {
|
||||
bool ctx_shift = true; // context shift on infinite text generation
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool logits_all = false; // return logits for all tokens in the batch
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
@ -324,14 +332,19 @@ struct common_params {
|
||||
bool no_kv_offload = false; // disable KV offloading
|
||||
bool warmup = true; // warmup run
|
||||
bool check_tensors = false; // validate tensor data
|
||||
bool no_op_offload = false; // globally disable offload host tensor operations to device
|
||||
|
||||
bool single_turn = false; // single turn chat conversation
|
||||
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
|
||||
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
|
||||
|
||||
// multimodal models (see examples/llava)
|
||||
std::string mmproj = ""; // path to multimodal projector // NOLINT
|
||||
// multimodal models (see tools/mtmd)
|
||||
struct common_params_model mmproj;
|
||||
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
||||
bool no_mmproj = false; // explicitly disable multimodal model
|
||||
std::vector<std::string> image; // path to image file(s)
|
||||
|
||||
// embedding
|
||||
@ -391,29 +404,28 @@ struct common_params {
|
||||
int32_t i_pos = -1; // position of the passkey in the junk text
|
||||
|
||||
// imatrix params
|
||||
std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
|
||||
|
||||
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
|
||||
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
|
||||
int32_t i_chunk = 0; // start processing from this chunk
|
||||
|
||||
bool process_output = false; // collect data for the output tensor
|
||||
bool compute_ppl = true; // whether to compute perplexity
|
||||
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
|
||||
|
||||
// cvector-generator params
|
||||
int n_pca_batch = 100;
|
||||
int n_pca_iterations = 1000;
|
||||
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
|
||||
std::string cvector_outfile = "control_vector.gguf";
|
||||
std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
|
||||
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
|
||||
std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
|
||||
std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
|
||||
|
||||
bool spm_infill = false; // suffix/prefix/middle pattern for infill
|
||||
|
||||
std::string lora_outfile = "ggml-lora-merged-f16.gguf";
|
||||
|
||||
// batched-bench params
|
||||
bool batched_bench_output_jsonl = false;
|
||||
|
||||
// common params
|
||||
std::string out_file; // output filename for all example programs
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
@ -453,6 +465,8 @@ std::string string_repeat(const std::string & str, size_t n);
|
||||
|
||||
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
|
||||
|
||||
std::string regex_escape(const std::string & s);
|
||||
|
||||
template<class T>
|
||||
static std::vector<T> string_split(const std::string & str, char delim) {
|
||||
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
|
||||
@ -530,26 +544,11 @@ struct llama_model_params common_model_params_to_llama ( common_params
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
|
||||
|
||||
struct llama_model * common_load_model_from_url(
|
||||
const std::string & model_url,
|
||||
const std::string & local_path,
|
||||
const std::string & hf_token,
|
||||
const struct llama_model_params & params);
|
||||
|
||||
struct llama_model * common_load_model_from_hf(
|
||||
const std::string & repo,
|
||||
const std::string & remote_path,
|
||||
const std::string & local_path,
|
||||
const std::string & hf_token,
|
||||
const struct llama_model_params & params);
|
||||
|
||||
std::pair<std::string, std::string> common_get_hf_file(
|
||||
const std::string & hf_repo_with_tag,
|
||||
const std::string & hf_token);
|
||||
|
||||
// clear LoRA adapters from context, then apply new list of adapters
|
||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
|
||||
|
||||
std::string get_model_endpoint();
|
||||
|
||||
//
|
||||
// Batch utils
|
||||
//
|
||||
@ -667,3 +666,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
|
||||
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
// training utils
|
||||
//
|
||||
|
||||
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
|
||||
|
@ -16,6 +16,9 @@ using json = nlohmann::ordered_json;
|
||||
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
|
||||
auto has_max = max_items != std::numeric_limits<int>::max();
|
||||
|
||||
if (max_items == 0) {
|
||||
return "";
|
||||
}
|
||||
if (min_items == 0 && max_items == 1) {
|
||||
return item_rule + "?";
|
||||
}
|
||||
@ -264,7 +267,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
||||
throw std::runtime_error("At least one of min_value or max_value must be set");
|
||||
}
|
||||
|
||||
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
|
||||
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
|
||||
|
||||
struct BuiltinRule {
|
||||
std::string content;
|
||||
@ -764,11 +767,10 @@ private:
|
||||
public:
|
||||
SchemaConverter(
|
||||
const std::function<json(const std::string &)> & fetch_json,
|
||||
bool dotall,
|
||||
bool compact_spaces)
|
||||
bool dotall)
|
||||
: _fetch_json(fetch_json), _dotall(dotall)
|
||||
{
|
||||
_rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
|
||||
_rules["space"] = SPACE_RULE;
|
||||
}
|
||||
|
||||
void resolve_refs(json & schema, const std::string & url) {
|
||||
@ -1007,7 +1009,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
|
||||
common_grammar_builder builder {
|
||||
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
|
||||
return converter._add_rule(name, rule);
|
||||
|
@ -16,7 +16,6 @@ struct common_grammar_builder {
|
||||
|
||||
struct common_grammar_options {
|
||||
bool dotall = false;
|
||||
bool compact_spaces = false;
|
||||
};
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
162  llama/llama.cpp/common/sampling.cpp  vendored
@ -1,9 +1,11 @@
|
||||
#include "sampling.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
// TODO: deduplicate with llama-impl.h
|
||||
@ -159,17 +161,57 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<const char *> trigger_words;
|
||||
trigger_words.reserve(params.grammar_trigger_words.size());
|
||||
for (const auto & str : params.grammar_trigger_words) {
|
||||
trigger_words.push_back(str.word.c_str());
|
||||
std::vector<std::string> patterns_at_start;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto & trigger : params.grammar_triggers) {
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
patterns_anywhere.push_back(regex_escape(word));
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
{
|
||||
const auto token = trigger.token;
|
||||
trigger_tokens.push_back(token);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown trigger type");
|
||||
}
|
||||
}
|
||||
|
||||
        std::vector<std::string> trigger_patterns;
        if (!patterns_at_start.empty()) {
            trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
        }
        if (!patterns_anywhere.empty()) {
            trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
        }

        std::vector<const char *> trigger_patterns_c;
        trigger_patterns_c.reserve(trigger_patterns.size());
        for (const auto & regex : trigger_patterns) {
            trigger_patterns_c.push_back(regex.c_str());
        }

        grmr = params.grammar_lazy
             ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
                    trigger_words.data(), trigger_words.size(),
                    params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
             ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
                    trigger_patterns_c.data(), trigger_patterns_c.size(),
                    trigger_tokens.data(), trigger_tokens.size())
             : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
        if (!grmr) {
            return nullptr;
        }
    }
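Worked example of the pattern construction above with two hypothetical lazy-grammar trigger words, "<tool_call>" and "TOOL:" (both matched anywhere in the output):

    // patterns_anywhere = { "<tool_call>", "TOOL:" }              (after regex_escape; no special characters here)
    // trigger_patterns  = { "^[\\s\\S]*?(<tool_call>|TOOL:)[\\s\\S]*" }
    // i.e. the lazy grammar activates as soon as either word appears anywhere in the generated text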
|
||||
|
||||
auto * result = new common_sampler {
|
||||
@ -188,51 +230,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
params.logit_bias.data()));
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
if (params.top_n_sigma >= 0) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
||||
} else {
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char *> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
std::vector<const char *> c_breakers;
|
||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||
for (const auto & str : params.dry_sequence_breakers) {
|
||||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
@ -434,6 +473,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
|
||||
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||
@ -449,6 +489,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
|
||||
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
|
||||
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||
@ -463,6 +504,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
||||
{ "dry", COMMON_SAMPLER_TYPE_DRY },
|
||||
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
|
||||
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
@ -476,6 +518,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
||||
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
|
||||
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
|
||||
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
@ -492,14 +535,16 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
||||
auto sampler = sampler_canonical_name_map.find(name);
|
||||
if (sampler != sampler_canonical_name_map.end()) {
|
||||
samplers.push_back(sampler->second);
|
||||
} else {
|
||||
if (allow_alt_names) {
|
||||
sampler = sampler_alt_name_map.find(name);
|
||||
if (sampler != sampler_alt_name_map.end()) {
|
||||
samplers.push_back(sampler->second);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (allow_alt_names) {
|
||||
sampler = sampler_alt_name_map.find(name);
|
||||
if (sampler != sampler_alt_name_map.end()) {
|
||||
samplers.push_back(sampler->second);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
|
||||
}
|
||||
|
||||
return samplers;
|
||||
@ -511,6 +556,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
||||
@ -525,6 +571,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
||||
const auto sampler = sampler_name_map.find(c);
|
||||
if (sampler != sampler_name_map.end()) {
|
||||
samplers.push_back(sampler->second);
|
||||
} else {
|
||||
LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
|
||||
}
|
||||
}
|
||||
|
||||
|
3032  llama/llama.cpp/examples/llava/clip.cpp  vendored
File diff suppressed because it is too large
240  llama/llama.cpp/include/llama.h  vendored
@ -4,6 +4,7 @@
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-opt.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
@ -60,6 +61,7 @@ extern "C" {
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
struct llama_sampler;
|
||||
struct llama_kv_cache;
|
||||
|
||||
typedef int32_t llama_pos;
|
||||
typedef int32_t llama_token;
|
||||
@ -106,6 +108,12 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
||||
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
|
||||
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
|
||||
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
|
||||
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
||||
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
|
||||
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
@ -250,7 +258,6 @@ extern "C" {
|
||||
|
||||
llama_token * token;
|
||||
float * embd;
|
||||
int32_t n_embd;
|
||||
llama_pos * pos;
|
||||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
@ -277,10 +284,18 @@ extern "C" {
|
||||
};
|
||||
};
|
||||
|
||||
struct llama_model_tensor_buft_override {
|
||||
const char * pattern;
|
||||
ggml_backend_buffer_type_t buft;
|
||||
};
|
||||
|
||||
struct llama_model_params {
|
||||
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
|
||||
ggml_backend_dev_t * devices;
|
||||
|
||||
// NULL-terminated list of buffer types to use for tensors that match a pattern
|
||||
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
|
||||
|
||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
|
||||
|
||||
@ -338,35 +353,34 @@ extern "C" {
|
||||
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
|
||||
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
|
||||
|
||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||
// TODO: move at the end of the struct
|
||||
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
||||
bool embeddings; // if true, extract embeddings (together with logits)
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // whether to measure performance timings
|
||||
bool cross_attn; // whether to use cross attention
|
||||
|
||||
// Abort callback
|
||||
// if it returns true, execution of llama_decode() will be aborted
|
||||
// currently works only with CPU execution
|
||||
ggml_abort_callback abort_callback;
|
||||
void * abort_callback_data;
|
||||
|
||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||
bool embeddings; // if true, extract embeddings (together with logits)
|
||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
||||
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
|
||||
bool no_perf; // whether to measure performance timings
|
||||
bool op_offload; // whether to offload host tensor operations to device
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
typedef struct llama_model_quantize_params {
|
||||
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
enum llama_ftype ftype; // quantize to this llama_ftype
|
||||
enum ggml_type output_tensor_type; // output tensor type
|
||||
enum ggml_type token_embedding_type; // token embeddings tensor type
|
||||
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||
bool quantize_output_tensor; // quantize output.weight
|
||||
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
bool pure; // quantize all tensors to the default type
|
||||
bool keep_split; // quantize to the same number of shards
|
||||
void * imatrix; // pointer to importance matrix data
|
||||
void * kv_overrides; // pointer to vector containing overrides
|
||||
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
enum llama_ftype ftype; // quantize to this llama_ftype
|
||||
enum ggml_type output_tensor_type; // output tensor type
|
||||
enum ggml_type token_embedding_type; // token embeddings tensor type
|
||||
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||
bool quantize_output_tensor; // quantize output.weight
|
||||
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
bool pure; // quantize all tensors to the default type
|
||||
bool keep_split; // quantize to the same number of shards
|
||||
void * imatrix; // pointer to importance matrix data
|
||||
void * kv_overrides; // pointer to vector containing overrides
|
||||
void * tensor_types; // pointer to vector containing tensor types
|
||||
} llama_model_quantize_params;
|
||||
|
||||
typedef struct llama_logit_bias {
|
||||
@ -432,6 +446,10 @@ extern "C" {
|
||||
size_t n_paths,
|
||||
struct llama_model_params params);
|
||||
|
||||
LLAMA_API void llama_model_save_to_file(
|
||||
const struct llama_model * model,
|
||||
const char * path_model);
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
|
||||
"use llama_model_free instead");
|
||||
|
||||
@ -446,10 +464,6 @@ extern "C" {
|
||||
struct llama_context_params params),
|
||||
"use llama_init_from_model instead");
|
||||
|
||||
// TODO (jmorganca): this should most likely be passed in as part of a batch
|
||||
// and not set on the context for all batches.
|
||||
LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
|
||||
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
@ -475,7 +489,8 @@ extern "C" {
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
|
||||
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
|
||||
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
@ -592,7 +607,7 @@ extern "C" {
|
||||
// KV cache
|
||||
//
|
||||
|
||||
// TODO: remove llama_kv_cache_view_* API
|
||||
// TODO: start using struct llama_kv_cache
|
||||
|
||||
// Information associated with an individual cell in the KV cache view.
|
||||
struct llama_kv_cache_view_cell {
|
||||
@ -647,13 +662,19 @@ extern "C" {
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
||||
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
||||
"use llama_kv_self_n_tokens instead");
|
||||
|
||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
||||
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
|
||||
"use llama_kv_self_used_cells instead");
|
||||
|
||||
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||
LLAMA_API void llama_kv_cache_clear(
|
||||
LLAMA_API void llama_kv_self_clear(
|
||||
struct llama_context * ctx);
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
@ -661,7 +682,7 @@ extern "C" {
|
||||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API bool llama_kv_cache_seq_rm(
|
||||
LLAMA_API bool llama_kv_self_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
@ -671,7 +692,7 @@ extern "C" {
|
||||
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_cache_seq_cp(
|
||||
LLAMA_API void llama_kv_self_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
@ -679,17 +700,17 @@ extern "C" {
|
||||
llama_pos p1);
|
||||
|
||||
// Removes all tokens that do not belong to the specified sequence
|
||||
LLAMA_API void llama_kv_cache_seq_keep(
|
||||
LLAMA_API void llama_kv_self_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_cache_update()
|
||||
// - explicitly with llama_kv_self_update()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_cache_seq_add(
|
||||
LLAMA_API void llama_kv_self_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
@ -699,10 +720,10 @@ extern "C" {
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_cache_update()
|
||||
// - explicitly with llama_kv_self_update()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_cache_seq_div(
|
||||
LLAMA_API void llama_kv_self_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
@ -710,24 +731,76 @@ extern "C" {
|
||||
int d);
|
||||
|
||||
// Returns the largest position present in the KV cache for the specified sequence
|
||||
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
|
||||
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
|
||||
// how to avoid this?
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_cache_update()
|
||||
LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
|
||||
// - explicitly with llama_kv_self_update()
|
||||
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
|
||||
|
||||
// Check if the context supports KV cache shifting
|
||||
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
|
||||
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
|
||||
struct llama_context * ctx),
|
||||
"use llama_kv_self_clear instead");
|
||||
|
||||
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"use llama_kv_self_seq_rm instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"use llama_kv_self_seq_cp instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"use llama_kv_self_seq_keep instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta),
|
||||
"use llama_kv_self_seq_add instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d),
|
||||
"use llama_kv_self_seq_div instead");
|
||||
|
||||
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"use llama_kv_self_seq_pos_max instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
|
||||
"use llama_kv_self_defrag instead");
|
||||
|
||||
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
|
||||
"use llama_kv_self_can_shift instead");
|
||||
|
||||
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
|
||||
"use llama_kv_self_update instead");
|
||||
|
||||
|
||||
//
|
||||
// State / sessions
|
||||
@ -856,14 +929,19 @@ extern "C" {
|
||||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||
|
||||
// Processes a batch of tokens with the encoder part of the encoder-decoder model.
|
||||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||
// Process a batch of tokens.
|
||||
// In contrast to llama_decode() - this call does not use KV cache.
|
||||
// For encoder-decoder contexts, processes the batch using the encoder.
|
||||
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
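A short usage sketch for llama_encode(), assuming `ctx`, `tokens` and `n_tokens` were prepared elsewhere (for example via llama_tokenize):

#include <stdio.h>
#include "llama.h"

// `ctx`, `tokens` and `n_tokens` are assumed to be set up by the caller
static void run_encoder(struct llama_context * ctx, llama_token * tokens, int32_t n_tokens) {
    struct llama_batch batch = llama_batch_get_one(tokens, n_tokens);
    if (llama_encode(ctx, batch) != 0) {
        // on error the KV cache state is restored to the state before this call
        fprintf(stderr, "llama_encode failed\n");
    }
}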
|
||||
|
||||
// Process a batch of tokens.
|
||||
// Requires KV cache.
|
||||
// For encoder-decoder contexts, processes the batch using the decoder.
|
||||
// Positive return values do not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
@ -891,6 +969,10 @@ extern "C" {
|
||||
// If set to true, the model will only attend to the past tokens
|
||||
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
|
||||
|
||||
// Set whether the model is in warmup mode or not
|
||||
// If true, all model tensors are activated during llama_decode() to load and cache their weights.
|
||||
LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
|
||||
|
||||
// Set abort callback
|
||||
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
||||
|
||||
@ -1160,6 +1242,7 @@ extern "C" {
|
||||
"will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
|
||||
|
||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
/// Setting k <= 0 makes this a noop
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
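For reference, a minimal sampler-chain sketch built around the top-k sampler declared above; the k value and seed are arbitrary choices for illustration:

// build a sampler chain: top-k filtering followed by random sampling
struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_dist(/*seed=*/1234));

// later: llama_token tok = llama_sampler_sample(chain, ctx, -1);
llama_sampler_free(chain);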
|
||||
|
||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
@ -1206,22 +1289,38 @@ extern "C" {
|
||||
float tau,
|
||||
float eta);
|
||||
|
||||
/// @details Initializes a GBNF grammar, see grammars/README.md for details.
|
||||
/// @param vocab The vocabulary that this grammar will be used with.
|
||||
/// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
|
||||
/// @param grammar_root The name of the start symbol for the grammar.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root);
|
||||
|
||||
/// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
|
||||
/// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in the near future.
|
||||
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
|
||||
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
|
||||
size_t num_trigger_tokens),
|
||||
"use llama_sampler_init_grammar_lazy_patterns instead");
|
||||
|
||||
|
||||
/// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
|
||||
/// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group.
|
||||
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler. The grammar sampler will be fed content starting from the trigger token (inclusive).
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
const char ** trigger_patterns,
|
||||
size_t num_trigger_patterns,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
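A hedged sketch of the non-lazy grammar sampler from above; the GBNF string is a toy grammar and `model` is assumed to be an already-loaded llama_model:

// constrain generation to the strings "yes" or "no" (toy grammar for illustration)
const struct llama_vocab * vocab = llama_model_get_vocab(model);
const char * grammar_str = "root ::= \"yes\" | \"no\"";

struct llama_sampler * grammar = llama_sampler_init_grammar(vocab, grammar_str, "root");
if (grammar != NULL) {
    // add it to a sampler chain, sample as usual, then llama_sampler_free(grammar) when done
}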
|
||||
|
||||
|
||||
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
|
||||
@ -1339,6 +1438,37 @@ extern "C" {
|
||||
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
|
||||
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
// function that returns whether or not a given tensor contains trainable parameters
|
||||
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
// always returns true
|
||||
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
struct llama_opt_params {
|
||||
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
|
||||
|
||||
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
|
||||
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
|
||||
|
||||
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
||||
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
||||
};
|
||||
|
||||
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
||||
LLAMA_API void llama_opt_epoch(
|
||||
struct llama_context * lctx,
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
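A sketch of wiring these training entry points together; it assumes `lctx`, `model`, `dataset` and the result objects already exist, and uses ggml_opt_get_default_optimizer_params() from ggml-opt.h as the optimizer callback (an assumption, not something shown in this diff):

// fill llama_opt_params with the simplest possible configuration
struct llama_opt_params opt_params = {
    /*.n_ctx_train     =*/ 0,                          // 0: use the context size from the llama_context
    /*.param_filter    =*/ llama_opt_param_filter_all, // train every tensor
    /*.param_filter_ud =*/ NULL,
    /*.get_opt_pars    =*/ ggml_opt_get_default_optimizer_params, // assumed default from ggml-opt.h
    /*.get_opt_pars_ud =*/ NULL,
};

llama_opt_init(lctx, model, opt_params);
// one epoch over `dataset`; idata_split selects the train/eval boundary, callbacks omitted in this sketch
llama_opt_epoch(lctx, dataset, result_train, result_eval, idata_split, NULL, NULL);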
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
llama/llama.cpp/src/llama-adapter.cpp (vendored, 81 changes)
@ -4,14 +4,13 @@
|
||||
#include "llama-mmap.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
|
||||
// vec
|
||||
|
||||
struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
|
||||
ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
|
||||
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -19,7 +18,7 @@ struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
|
||||
return tensors[il];
|
||||
}
|
||||
|
||||
struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
|
||||
ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const {
|
||||
ggml_tensor * layer_dir = tensor_for(il);
|
||||
if (layer_dir != nullptr) {
|
||||
cur = ggml_add(ctx, cur, layer_dir);
|
||||
@ -40,7 +39,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
|
||||
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
struct ggml_init_params params = {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
@ -91,7 +90,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t llama_adapter_cvec::apply(
|
||||
bool llama_adapter_cvec::apply(
|
||||
const llama_model & model,
|
||||
const float * data,
|
||||
size_t len,
|
||||
@ -104,17 +103,17 @@ int32_t llama_adapter_cvec::apply(
|
||||
// disable the current control vector (but leave allocated for later)
|
||||
layer_start = -1;
|
||||
layer_end = -1;
|
||||
return 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (n_embd != (int) hparams.n_embd) {
|
||||
LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
|
||||
return 1;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensors.empty()) {
|
||||
if (!init(model)) {
|
||||
return 1;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -130,12 +129,12 @@ int32_t llama_adapter_cvec::apply(
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
// lora
|
||||
|
||||
llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
|
||||
llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
|
||||
const std::string name(w->name);
|
||||
|
||||
const auto pos = ab_map.find(name);
|
||||
@ -146,11 +145,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor *
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
|
||||
static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
|
||||
LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
|
||||
|
||||
ggml_context * ctx_init;
|
||||
struct gguf_init_params meta_gguf_params = {
|
||||
gguf_init_params meta_gguf_params = {
|
||||
/* .no_alloc = */ true,
|
||||
/* .ctx = */ &ctx_init,
|
||||
};
|
||||
@ -201,7 +200,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
// add a new context
|
||||
struct ggml_init_params params = {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
@ -248,6 +247,29 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
|
||||
}
|
||||
}
|
||||
|
||||
// get extra buffer types of the CPU
|
||||
// TODO: a more general solution for non-CPU extra buft should be implemented in the future
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
|
||||
std::vector<ggml_backend_buffer_type_t> buft_extra;
|
||||
{
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
||||
|
||||
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
||||
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
|
||||
|
||||
if (ggml_backend_dev_get_extra_bufts_fn) {
|
||||
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
|
||||
while (extra_bufts && *extra_bufts) {
|
||||
buft_extra.emplace_back(*extra_bufts);
|
||||
++extra_bufts;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add tensors
|
||||
for (auto & it : ab_map) {
|
||||
const std::string & name = it.first;
|
||||
@ -264,7 +286,26 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
|
||||
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
|
||||
}
|
||||
|
||||
struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
|
||||
auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer);
|
||||
|
||||
// do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
|
||||
for (auto & ex : buft_extra) {
|
||||
if (ex == buft) {
|
||||
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
|
||||
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
buft = ggml_backend_dev_buffer_type(cpu_dev);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
|
||||
|
||||
ggml_context * dev_ctx = ctx_for_buft(buft);
|
||||
// validate tensor shape
|
||||
if (is_token_embd) {
|
||||
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
|
||||
@ -281,8 +322,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
|
||||
}
|
||||
|
||||
// save tensor to adapter
|
||||
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
|
||||
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
|
||||
ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
|
||||
ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
|
||||
ggml_set_name(tensor_a, w.a->name);
|
||||
ggml_set_name(tensor_b, w.b->name);
|
||||
adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
|
||||
@ -308,7 +349,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
|
||||
{
|
||||
llama_file gguf_file(path_lora, "rb");
|
||||
std::vector<uint8_t> read_buf;
|
||||
auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
|
||||
auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
|
||||
size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
|
||||
size_t size = ggml_nbytes(orig);
|
||||
read_buf.resize(size);
|
||||
@ -327,8 +368,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
|
||||
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
|
||||
}
|
||||
|
||||
struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
|
||||
struct llama_adapter_lora * adapter = new llama_adapter_lora();
|
||||
llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
|
||||
llama_adapter_lora * adapter = new llama_adapter_lora();
|
||||
|
||||
try {
|
||||
llama_adapter_lora_init_impl(*model, path_lora, *adapter);
|
||||
@ -342,6 +383,6 @@ struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
|
||||
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
|
||||
delete adapter;
|
||||
}
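For context, a hedged sketch of how these adapter entry points are typically driven through the public API; llama_set_adapter_lora() and llama_rm_adapter_lora() are assumed from llama.h, and the adapter path is a placeholder:

// load a LoRA adapter from disk and attach it to a context at full strength
struct llama_adapter_lora * adapter = llama_adapter_lora_init(model, "adapter.gguf");
if (adapter != NULL) {
    llama_set_adapter_lora(ctx, adapter, /*scale=*/1.0f);
    // ... run inference ...
    llama_rm_adapter_lora(ctx, adapter);
    llama_adapter_lora_free(adapter);
}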
|
||||
|
llama/llama.cpp/src/llama-adapter.h (vendored, 20 changes)
@ -15,11 +15,11 @@
|
||||
//
|
||||
|
||||
struct llama_adapter_cvec {
|
||||
struct ggml_tensor * tensor_for(int il) const;
|
||||
ggml_tensor * tensor_for(int il) const;
|
||||
|
||||
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
|
||||
ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const;
|
||||
|
||||
int32_t apply(
|
||||
bool apply(
|
||||
const llama_model & model,
|
||||
const float * data,
|
||||
size_t len,
|
||||
@ -36,7 +36,7 @@ private:
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
std::vector<struct ggml_tensor *> tensors; // per layer
|
||||
std::vector<ggml_tensor *> tensors; // per layer
|
||||
};
|
||||
|
||||
//
|
||||
@ -44,8 +44,8 @@ private:
|
||||
//
|
||||
|
||||
struct llama_adapter_lora_weight {
|
||||
struct ggml_tensor * a = nullptr;
|
||||
struct ggml_tensor * b = nullptr;
|
||||
ggml_tensor * a = nullptr;
|
||||
ggml_tensor * b = nullptr;
|
||||
|
||||
// get actual scale based on rank and alpha
|
||||
float get_scale(float alpha, float adapter_scale) const {
|
||||
@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
|
||||
}
|
||||
|
||||
llama_adapter_lora_weight() = default;
|
||||
llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
|
||||
llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
|
||||
};
|
||||
|
||||
struct llama_adapter_lora {
|
||||
// map tensor name to lora_a_b
|
||||
std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
|
||||
std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
@ -70,5 +70,7 @@ struct llama_adapter_lora {
|
||||
llama_adapter_lora() = default;
|
||||
~llama_adapter_lora() = default;
|
||||
|
||||
llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
|
||||
llama_adapter_lora_weight * get_weight(ggml_tensor * w);
|
||||
};
|
||||
|
||||
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
|
||||
|
llama/llama.cpp/src/llama-arch.cpp (vendored, 294 changes)
@ -6,7 +6,7 @@
|
||||
|
||||
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_LLAMA, "llama" },
|
||||
{ LLM_ARCH_MLLAMA, "mllama" },
|
||||
{ LLM_ARCH_LLAMA4, "llama4" },
|
||||
{ LLM_ARCH_DECI, "deci" },
|
||||
{ LLM_ARCH_FALCON, "falcon" },
|
||||
{ LLM_ARCH_GROK, "grok" },
|
||||
@ -19,6 +19,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_REFACT, "refact" },
|
||||
{ LLM_ARCH_BERT, "bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
|
||||
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
|
||||
{ LLM_ARCH_BLOOM, "bloom" },
|
||||
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||
@ -26,6 +27,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_QWEN2, "qwen2" },
|
||||
{ LLM_ARCH_QWEN2MOE, "qwen2moe" },
|
||||
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
|
||||
{ LLM_ARCH_QWEN3, "qwen3" },
|
||||
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
|
||||
{ LLM_ARCH_PHI2, "phi2" },
|
||||
{ LLM_ARCH_PHI3, "phi3" },
|
||||
{ LLM_ARCH_PHIMOE, "phimoe" },
|
||||
@ -52,6 +55,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_DEEPSEEK, "deepseek" },
|
||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||
{ LLM_ARCH_CHATGLM, "chatglm" },
|
||||
{ LLM_ARCH_GLM4, "glm4" },
|
||||
{ LLM_ARCH_BITNET, "bitnet" },
|
||||
{ LLM_ARCH_T5, "t5" },
|
||||
{ LLM_ARCH_T5ENCODER, "t5encoder" },
|
||||
@ -60,11 +64,15 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
|
||||
{ LLM_ARCH_RWKV7, "rwkv7" },
|
||||
{ LLM_ARCH_ARWKV7, "arwkv7" },
|
||||
{ LLM_ARCH_GRANITE, "granite" },
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
{ LLM_ARCH_SOLAR, "solar" },
|
||||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||
{ LLM_ARCH_PLM, "plm" },
|
||||
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@ -73,6 +81,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
|
||||
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
|
||||
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
|
||||
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
|
||||
{ LLM_KV_GENERAL_NAME, "general.name" },
|
||||
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
|
||||
{ LLM_KV_GENERAL_VERSION, "general.version" },
|
||||
@ -99,6 +108,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
|
||||
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
|
||||
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
|
||||
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
|
||||
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
|
||||
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
|
||||
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
|
||||
@ -111,25 +121,31 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
|
||||
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
|
||||
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
|
||||
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
|
||||
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
|
||||
{ LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
|
||||
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
@ -229,7 +245,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MLLAMA,
|
||||
LLM_ARCH_LLAMA4,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
@ -252,14 +268,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
|
||||
{ LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -465,6 +476,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
|
||||
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
{
|
||||
@ -593,6 +622,45 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_QWEN3MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PHI2,
|
||||
{
|
||||
@ -810,9 +878,12 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
@ -1056,6 +1127,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
||||
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
||||
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
|
||||
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
@ -1072,6 +1145,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PLM,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
|
||||
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_CHATGLM,
|
||||
{
|
||||
@ -1090,6 +1179,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GLM4,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BITNET,
|
||||
{
|
||||
@ -1274,6 +1382,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_RWKV7,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
|
||||
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
|
||||
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
|
||||
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
|
||||
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
|
||||
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
|
||||
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
|
||||
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
|
||||
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
|
||||
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
|
||||
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
||||
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
||||
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
||||
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
|
||||
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
|
||||
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ARWKV7,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
|
||||
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
|
||||
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
|
||||
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
|
||||
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
|
||||
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
|
||||
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
|
||||
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
|
||||
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
|
||||
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
||||
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
||||
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
||||
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
|
||||
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GRANITE,
|
||||
{
|
||||
@ -1371,6 +1547,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BAILINGMOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@ -1408,23 +1607,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
@ -1451,6 +1635,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
@ -1469,6 +1659,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
@ -1476,6 +1669,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
|
||||
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
@ -1505,14 +1701,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
// this tensor is loaded for T5, but never used
|
||||
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
||||
{LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CROSS_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CROSS_ATTN_K_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CROSS_ATTN_O_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CROSS_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CROSS_ATTN_Q_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CROSS_ATTN_V_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CROSS_ATTN_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CROSS_ATTN_MLP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
|
||||
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
llama/llama.cpp/src/llama-arch.h (vendored, 44 changes)
@ -10,7 +10,7 @@
|
||||
|
||||
enum llm_arch {
|
||||
LLM_ARCH_LLAMA,
|
||||
LLM_ARCH_MLLAMA,
|
||||
LLM_ARCH_LLAMA4,
|
||||
LLM_ARCH_DECI,
|
||||
LLM_ARCH_FALCON,
|
||||
LLM_ARCH_BAICHUAN,
|
||||
@ -23,6 +23,7 @@ enum llm_arch {
|
||||
LLM_ARCH_REFACT,
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_NOMIC_BERT,
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
LLM_ARCH_BLOOM,
|
||||
LLM_ARCH_STABLELM,
|
||||
@ -30,6 +31,8 @@ enum llm_arch {
|
||||
LLM_ARCH_QWEN2,
|
||||
LLM_ARCH_QWEN2MOE,
|
||||
LLM_ARCH_QWEN2VL,
|
||||
LLM_ARCH_QWEN3,
|
||||
LLM_ARCH_QWEN3MOE,
|
||||
LLM_ARCH_PHI2,
|
||||
LLM_ARCH_PHI3,
|
||||
LLM_ARCH_PHIMOE,
|
||||
@ -56,6 +59,7 @@ enum llm_arch {
|
||||
LLM_ARCH_DEEPSEEK,
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_CHATGLM,
|
||||
LLM_ARCH_GLM4,
|
||||
LLM_ARCH_BITNET,
|
||||
LLM_ARCH_T5,
|
||||
LLM_ARCH_T5ENCODER,
|
||||
@ -64,11 +68,15 @@ enum llm_arch {
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_RWKV6,
|
||||
LLM_ARCH_RWKV6QWEN2,
|
||||
LLM_ARCH_RWKV7,
|
||||
LLM_ARCH_ARWKV7,
|
||||
LLM_ARCH_GRANITE,
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_CHAMELEON,
|
||||
LLM_ARCH_SOLAR,
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
LLM_ARCH_PLM,
|
||||
LLM_ARCH_BAILINGMOE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -77,6 +85,7 @@ enum llm_kv {
|
||||
LLM_KV_GENERAL_ARCHITECTURE,
|
||||
LLM_KV_GENERAL_QUANTIZATION_VERSION,
|
||||
LLM_KV_GENERAL_ALIGNMENT,
|
||||
LLM_KV_GENERAL_FILE_TYPE,
|
||||
LLM_KV_GENERAL_NAME,
|
||||
LLM_KV_GENERAL_AUTHOR,
|
||||
LLM_KV_GENERAL_VERSION,
|
||||
@ -103,6 +112,7 @@ enum llm_kv {
|
||||
LLM_KV_EXPERT_WEIGHTS_SCALE,
|
||||
LLM_KV_EXPERT_WEIGHTS_NORM,
|
||||
LLM_KV_EXPERT_GATING_FUNC,
|
||||
LLM_KV_MOE_EVERY_N_LAYERS,
|
||||
LLM_KV_POOLING_TYPE,
|
||||
LLM_KV_LOGIT_SCALE,
|
||||
LLM_KV_DECODER_START_TOKEN_ID,
|
||||
@ -115,6 +125,7 @@ enum llm_kv {
|
||||
LLM_KV_RESIDUAL_SCALE,
|
||||
LLM_KV_EMBEDDING_SCALE,
|
||||
LLM_KV_TOKEN_SHIFT_COUNT,
|
||||
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
|
||||
|
||||
LLM_KV_ATTENTION_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_HEAD_COUNT_KV,
|
||||
@ -129,11 +140,16 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_CAUSAL,
|
||||
LLM_KV_ATTENTION_Q_LORA_RANK,
|
||||
LLM_KV_ATTENTION_KV_LORA_RANK,
|
||||
LLM_KV_ATTENTION_DECAY_LORA_RANK,
|
||||
LLM_KV_ATTENTION_ICLR_LORA_RANK,
|
||||
LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
|
||||
LLM_KV_ATTENTION_GATE_LORA_RANK,
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
|
||||
LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
@ -247,6 +263,8 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_LAYER_OUT_NORM,
|
||||
LLM_TENSOR_POST_ATTN_NORM,
|
||||
LLM_TENSOR_POST_MLP_NORM,
|
||||
LLM_TENSOR_SSM_IN,
|
||||
LLM_TENSOR_SSM_CONV1D,
|
||||
LLM_TENSOR_SSM_X,
|
||||
@ -254,8 +272,20 @@ enum llm_tensor {
|
||||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_TIME_MIX_W0,
|
||||
LLM_TENSOR_TIME_MIX_W1,
|
||||
LLM_TENSOR_TIME_MIX_W2,
|
||||
LLM_TENSOR_TIME_MIX_A0,
|
||||
LLM_TENSOR_TIME_MIX_A1,
|
||||
LLM_TENSOR_TIME_MIX_A2,
|
||||
LLM_TENSOR_TIME_MIX_V0,
|
||||
LLM_TENSOR_TIME_MIX_V1,
|
||||
LLM_TENSOR_TIME_MIX_V2,
|
||||
LLM_TENSOR_TIME_MIX_G1,
|
||||
LLM_TENSOR_TIME_MIX_G2,
|
||||
LLM_TENSOR_TIME_MIX_K_K,
|
||||
LLM_TENSOR_TIME_MIX_K_A,
|
||||
LLM_TENSOR_TIME_MIX_R_K,
|
||||
LLM_TENSOR_TIME_MIX_LERP_X,
|
||||
LLM_TENSOR_TIME_MIX_LERP_W,
|
||||
LLM_TENSOR_TIME_MIX_LERP_K,
|
||||
@ -282,6 +312,8 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_Q_B,
|
||||
LLM_TENSOR_ATTN_KV_A_MQA,
|
||||
LLM_TENSOR_ATTN_KV_B,
|
||||
LLM_TENSOR_ATTN_K_B,
|
||||
LLM_TENSOR_ATTN_V_B,
|
||||
LLM_TENSOR_ATTN_Q_A_NORM,
|
||||
LLM_TENSOR_ATTN_KV_A_NORM,
|
||||
LLM_TENSOR_ATTN_SUB_NORM,
|
||||
@ -317,14 +349,6 @@ enum llm_tensor {
|
||||
LLM_TENSOR_CLS,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
LLM_TENSOR_BSKCN_TV,
|
||||
LLM_TENSOR_CROSS_ATTN_K_NORM,
|
||||
LLM_TENSOR_CROSS_ATTN_K_PROJ,
|
||||
LLM_TENSOR_CROSS_ATTN_O_PROJ,
|
||||
LLM_TENSOR_CROSS_ATTN_Q_NORM,
|
||||
LLM_TENSOR_CROSS_ATTN_Q_PROJ,
|
||||
LLM_TENSOR_CROSS_ATTN_V_PROJ,
|
||||
LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
|
||||
LLM_TENSOR_CROSS_ATTN_MLP_GATE,
|
||||
LLM_TENSOR_CONV1D,
|
||||
LLM_TENSOR_CONVNEXT_DW,
|
||||
LLM_TENSOR_CONVNEXT_NORM,
|
||||
|
llama/llama.cpp/src/llama-batch.cpp (vendored, 9 changes)
@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
GGML_ASSERT(batch.n_tokens >= 0);
|
||||
this->batch = &batch;
|
||||
this->n_embd = n_embd;
|
||||
@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
ids[i] = i;
|
||||
}
|
||||
|
||||
if (simple_split) {
|
||||
seq.resize(1);
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
s.length = n_tokens;
|
||||
return;
|
||||
}
|
||||
|
||||
std::sort(ids.begin(), ids.end(),
|
||||
[&batch](size_t a, size_t b) {
|
||||
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
|
||||
@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
return n_seq_a > n_seq_b;
|
||||
}
|
||||
);
|
||||
|
||||
// init seq
|
||||
llama_sbatch_seq * last_seq = nullptr;
|
||||
|
||||
@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
seq.push_back(new_seq);
|
||||
last_seq = &seq.back();
|
||||
}
|
||||
|
||||
// keep shared prompts first at the end, then sort by length descending.
|
||||
std::sort(seq.begin(), seq.end(),
|
||||
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
|
||||
@ -316,7 +320,6 @@ struct llama_batch llama_batch_get_one(
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ tokens,
|
||||
/*embd =*/ nullptr,
|
||||
/*n_embd =*/ 0,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
@ -329,7 +332,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||
/*n_tokens =*/ 0,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ nullptr,
|
||||
/*n_embd =*/ 0,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
@ -338,7 +340,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
||||
|
||||
if (embd) {
|
||||
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
||||
batch.n_embd = embd;
|
||||
} else {
|
||||
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
||||
}
|
||||
|
7 llama/llama.cpp/src/llama-batch.h vendored
@@ -42,9 +42,9 @@ struct llama_sbatch {
bool logits_all; // TODO: remove once lctx.logits_all is removed too

// sorted indices into the batch
std::vector<size_t> ids;
std::vector<int64_t> ids;
// batch indices of the output
std::vector<size_t> out_ids;
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;

const llama_batch * batch = nullptr;
@@ -70,7 +70,8 @@ struct llama_sbatch {
// sequence-wise split
llama_ubatch split_seq(size_t n_ubatch);

void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};

// temporary allocate memory for the input batch if needed

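The llama-batch change above replaces the free-standing from_batch() initializer with a constructor and widens the sorted index vectors ids/out_ids from size_t to int64_t. A minimal sketch of the call-site impact, assuming a caller that already holds a llama_batch named batch and an embedding width n_embd (the helper name and variables are illustrative, not taken from this diff):

#include "llama-batch.h"

// Sketch only: shows how a caller adapts to the constructor-based API.
static llama_sbatch make_simple_sbatch(const llama_batch & batch, size_t n_embd) {
    // Before: llama_sbatch sbatch; sbatch.from_batch(batch, n_embd, /*simple_split=*/true, /*logits_all=*/false);
    // After:  the same initialization happens in the constructor.
    llama_sbatch sbatch(batch, n_embd, /*simple_split=*/true, /*logits_all=*/false);

    // ids and out_ids are now std::vector<int64_t>, so index bookkeeping should use int64_t.
    for (int64_t out_id : sbatch.out_ids) {
        (void) out_id; // e.g. map outputs back to the user-provided batch order
    }
    return sbatch;
}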
108 llama/llama.cpp/src/llama-chat.cpp vendored
@@ -4,6 +4,7 @@

#include <map>
#include <sstream>
#include <algorithm>

#if __cplusplus >= 202000L
#define LU8(x) (const char*)(u8##x)
@@ -34,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
@@ -49,8 +51,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGLM_4 },
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
@@ -58,6 +60,10 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
};

llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -77,7 +83,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
if (tmpl_contains("<|im_start|>")) {
return tmpl_contains("<|im_sep|>")
? LLM_CHAT_TEMPLATE_PHI_4
: LLM_CHAT_TEMPLATE_CHATML;
: tmpl_contains("<end_of_utterance>")
? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
: LLM_CHAT_TEMPLATE_CHATML;
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
if (tmpl_contains("[SYSTEM_PROMPT]")) {
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
@@ -115,8 +123,12 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
}
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
return LLM_CHAT_TEMPLATE_PHI_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGLM_4;
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
return LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
return LLM_CHAT_TEMPLATE_ZEPHYR;
} else if (tmpl_contains("bos_token + message['role']")) {
@@ -145,9 +157,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA_3;
} else if (tmpl_contains("[gMASK]sop")) {
// chatglm3-6b
return LLM_CHAT_TEMPLATE_CHATGML_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGML_4;
return LLM_CHAT_TEMPLATE_CHATGLM_3;
} else if (tmpl_contains(LU8("<用户>"))) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
return LLM_CHAT_TEMPLATE_MINICPM;
@@ -167,6 +177,12 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GIGACHAT;
} else if (tmpl_contains("<|role_start|>")) {
return LLM_CHAT_TEMPLATE_MEGREZ;
} else if (tmpl_contains(" Ассистент:")) {
return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -187,19 +203,20 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
// Official mistral 'v7' template
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
for (auto message : chat) {
std::string role(message->role);
std::string content(message->content);
if (role == "system") {
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
} else if (role == "user") {
ss << "[INST] " << content << "[/INST]";
}
else {
ss << " " << content << "</s>";
ss << "[INST]" << trailing_space << content << "[/INST]";
} else {
ss << trailing_space << content << "</s>";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
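As the trailing_space logic above implies, the v7 and v7-tekken variants differ only by a single space after each opening tag. A self-contained sketch that mirrors the branch above (it does not call llama.cpp; the two-message chat is made up):

#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Re-implements only the Mistral v7 / v7-tekken branch shown above, for illustration.
int main() {
    const std::vector<std::pair<std::string, std::string>> chat = {
        {"system", "You are helpful."}, {"user", "Hi"},
    };
    for (bool tekken : {false, true}) {
        const char * trailing_space = tekken ? "" : " "; // v7 keeps the space, v7-tekken drops it
        std::ostringstream ss;
        for (const auto & m : chat) {
            if (m.first == "system") {
                ss << "[SYSTEM_PROMPT]" << trailing_space << m.second << "[/SYSTEM_PROMPT]";
            } else if (m.first == "user") {
                ss << "[INST]" << trailing_space << m.second << "[/INST]";
            } else {
                ss << trailing_space << m.second << "</s>";
            }
        }
        std::cout << (tekken ? "v7-tekken: " : "v7:        ") << ss.str() << "\n";
    }
    return 0;
}

With that input, v7 renders "[SYSTEM_PROMPT] You are helpful.[/SYSTEM_PROMPT][INST] Hi[/INST]", while v7-tekken produces the same string without the spaces after the opening tags.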
@@ -422,7 +439,7 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) {
// chatglm3-6b
ss << "[gMASK]" << "sop";
for (auto message : chat) {
@@ -432,14 +449,14 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
ss << "<|assistant|>\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) {
@@ -566,6 +583,66 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|role_start|>assistant<|role_end|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
// Yandex template ("\n\n" is defined as EOT token)

ss << "<s>";

for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "user") {
ss << " Пользователь: " << chat[i]->content << "\n\n";
} else if (role == "assistant") {
ss << " Ассистент: " << chat[i]->content << "\n\n";
}
}

// Add generation prompt if needed
if (add_ass) {
ss << " Ассистент:[SEP]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
// Bailing (Ling) template
for (auto message : chat) {
std::string role(message->role);

if (role == "user") {
role = "HUMAN";
} else {
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
}

ss << "<role>" << role << "</role>" << message->content;
}

if (add_ass) {
ss << "<role>ASSISTANT</role>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
// Llama 4
for (auto message : chat) {
std::string role(message->role);
ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
}
if (add_ass) {
ss << "<|header_start|>assistant<|header_end|>\n\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
// SmolVLM
ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "\n\n";
} else if (role == "user") {
ss << "User: " << message->content << "<end_of_utterance>\n";
} else {
ss << "Assistant: " << message->content << "<end_of_utterance>\n";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;
@@ -584,4 +661,3 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
}
return (int32_t) LLM_CHAT_TEMPLATES.size();
}

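For reference, the new SmolVLM branch above produces a prompt that starts with <|im_start|> as BOS but is otherwise not ChatML. A self-contained sketch mirroring that branch (the messages are invented; it does not call llama.cpp):

#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Re-implements only the SmolVLM branch shown above, for illustration.
int main() {
    const std::vector<std::pair<std::string, std::string>> chat = {
        {"system", "You are a helpful assistant."},
        {"user", "Describe the image."},
    };
    std::ostringstream ss;
    ss << "<|im_start|>"; // used as BOS; the rest is not ChatML
    for (const auto & m : chat) {
        if (m.first == "system") {
            ss << m.second << "\n\n";
        } else if (m.first == "user") {
            ss << "User: " << m.second << "<end_of_utterance>\n";
        } else {
            ss << "Assistant: " << m.second << "<end_of_utterance>\n";
        }
    }
    ss << "Assistant:"; // add_ass == true
    std::cout << ss.str() << "\n";
    return 0;
}

This prints the system text, a blank line, "User: Describe the image.<end_of_utterance>", and a trailing "Assistant:" generation prompt.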
9 llama/llama.cpp/src/llama-chat.h vendored
@@ -14,6 +14,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MISTRAL_V3,
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
LLM_CHAT_TEMPLATE_MISTRAL_V7,
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
LLM_CHAT_TEMPLATE_PHI_3,
LLM_CHAT_TEMPLATE_PHI_4,
LLM_CHAT_TEMPLATE_FALCON_3,
@@ -29,8 +30,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_DEEPSEEK_3,
LLM_CHAT_TEMPLATE_COMMAND_R,
LLM_CHAT_TEMPLATE_LLAMA_3,
LLM_CHAT_TEMPLATE_CHATGML_3,
LLM_CHAT_TEMPLATE_CHATGML_4,
LLM_CHAT_TEMPLATE_CHATGLM_3,
LLM_CHAT_TEMPLATE_CHATGLM_4,
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
@@ -38,6 +39,10 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

3723 llama/llama.cpp/src/llama-context.cpp vendored
File diff suppressed because it is too large
313 llama/llama.cpp/src/llama-context.h vendored
@@ -3,66 +3,222 @@
#include "llama.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model.h"
#include "llama-kv-cache.h"
#include "llama-graph.h"
#include "llama-adapter.h"
#include "llama-kv-cache.h"

#include "ggml-cpp.h"
#include "ggml-opt.h"

#include <map>
#include <unordered_map>
#include <vector>
#include <set>

struct llama_model;
struct llama_kv_cache;

class llama_io_read_i;
class llama_io_write_i;

struct llama_context {
llama_context(const llama_model & model)
: model(model)
, t_start_us(model.t_start_us)
, t_load_us(model.t_load_us) {}
// init scheduler and compute buffers, reserve worst-case graphs
llama_context(
const llama_model & model,
llama_context_params params);

const struct llama_model & model;
~llama_context();

struct llama_cparams cparams;
struct llama_sbatch sbatch; // TODO: revisit if needed
struct llama_kv_cache kv_self;
struct llama_adapter_cvec cvec;
void synchronize();

std::unordered_map<struct llama_adapter_lora *, float> lora;
const llama_model & get_model() const;
const llama_cparams & get_cparams() const;

std::vector<ggml_backend_ptr> backends;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
ggml_backend_sched_t get_sched() const;

ggml_backend_t backend_cpu = nullptr;
ggml_context * get_ctx_compute() const;

ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;

bool has_evaluated_once = false;
uint32_t n_threads() const;
uint32_t n_threads_batch() const;

mutable int64_t t_start_us;
mutable int64_t t_load_us;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
llama_kv_cache * get_kv_self();
const llama_kv_cache * get_kv_self() const;

mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
void kv_self_update();

mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
enum llama_pooling_type pooling_type() const;

// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
float * get_logits();
float * get_logits_ith(int32_t i);

float * get_embeddings();
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);

void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch);

void detach_threadpool();

void set_n_threads(int32_t n_threads, int32_t n_threads_batch);

void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);

void set_embeddings (bool value);
void set_causal_attn(bool value);
void set_warmup(bool value);

void set_adapter_lora(
llama_adapter_lora * adapter,
float scale);

bool rm_adapter_lora(
llama_adapter_lora * adapter);

void clear_adapter_lora();

bool apply_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);

int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch);

//
// state save/load
//

size_t state_get_size();
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);

size_t state_seq_get_size(llama_seq_id seq_id);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);

bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);

bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count);

size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);

size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count);

//
// perf
//

llama_perf_context_data perf_get_data() const;
void perf_reset();

//
// training
//

void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);

void opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);

void opt_epoch_iter(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,
int64_t ndata_in_loop,
int64_t t_loop_start);

private:
//
// output
//

// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
int32_t output_reserve(int32_t n_outputs);

//
// graph
//

public:
int32_t graph_max_nodes() const;

// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();

// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);

private:
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype);

llm_graph_cb graph_get_cb() const;

// TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);

size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);

//
// members
//

const llama_model & model;

llama_cparams cparams;
llama_adapter_cvec cvec;
llama_adapter_loras loras;

llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably

std::unique_ptr<llama_memory_i> memory;

// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;

std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

bool logits_all = false;

// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
size_t embd_size = 0; // capacity (of floats) for embeddings
@@ -72,59 +228,50 @@ struct llama_context {
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;

// whether we are computing encoder output or decoder output
bool is_encoding = false;
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers

// TODO: find a better way to accommodate mutli-dimension position encoding methods
// number of position id each token get, 1 for each token in most cases.
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
int n_pos_per_token = 1;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers

// output of the encoder part of the encoder-decoder models
std::vector<float> embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;

// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_ptr sched;

ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;

ggml_context_ptr ctx_compute;

// training
ggml_opt_context_t opt_ctx = nullptr;

ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;

ggml_abort_callback abort_callback = nullptr;
void * abort_callback_data = nullptr;

// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;

struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
// buffer types used for the compute buffer of each backend
std::vector<ggml_backend_t> backend_ptrs;
std::vector<ggml_backend_buffer_type_t> backend_buft;

// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;

// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;

bool has_evaluated_once = false;

// perf
mutable int64_t t_start_us = 0;
mutable int64_t t_load_us = 0;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;

mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;

mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
};

// TODO: make these methods of llama_context
void llama_set_k_shift(struct llama_context & lctx);

void llama_set_s_copy(struct llama_context & lctx);

void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);

// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);

// make the outputs have the same order they had in the user-provided batch
void llama_output_reorder(struct llama_context & ctx);

// For internal test use
// TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);

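The llama-context.h rewrite above turns llama_context from a struct with public members into a class-like interface: fields such as kv_self, cparams, and sbatch move behind accessors (get_kv_self(), get_cparams(), n_ctx(), ...), construction now takes llama_context_params, and the scheduler and compute buffers are set up inside the constructor. A hedged sketch of how an internal call site might migrate (the helper name and the old-style field accesses are illustrative, not taken from this diff):

#include "llama-context.h"

// Sketch only: shows the shape of the field-to-accessor migration.
static void kv_housekeeping(llama_context & ctx) {
    // Before: public fields such as ctx.kv_self and ctx.cparams were touched directly.
    // After:  state is reached through accessors and the context drives its own updates.
    const uint32_t n_ctx = ctx.n_ctx();        // rather than reading cparams directly
    llama_kv_cache * kv  = ctx.get_kv_self();  // rather than &ctx.kv_self
    (void) n_ctx;
    (void) kv;
    ctx.kv_self_update();                      // apply any pending KV-cache updates internally
}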
3 llama/llama.cpp/src/llama-cparams.h vendored
@@ -29,7 +29,8 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
bool no_perf;
bool cross_attn;
bool warmup;
bool op_offload;

enum llama_pooling_type pooling_type;

Some files were not shown because too many files have changed in this diff.