Compare commits

..

2 Commits

Author SHA1 Message Date
ParthSareen
3bc9d42e2e rebase + fix tests 2025-04-03 17:31:21 -07:00
ParthSareen
4053c489b4 server: enable content streaming with tools 2025-04-03 17:09:59 -07:00
352 changed files with 43983 additions and 60541 deletions

View File

@ -432,22 +432,6 @@ jobs:
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
working-directory: ${{ runner.temp }}
# Trigger downstream release process
trigger:
runs-on: ubuntu-latest
environment: release
needs: [darwin-build, windows-build, windows-depends]
steps:
- name: Trigger downstream release process
run: |
curl -L \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
# Aggregate all the assets and ship a release
release:
needs: [darwin-sign, windows-sign, linux-build]

View File

@ -237,5 +237,5 @@ jobs:
- uses: actions/checkout@v4
- name: Verify patches apply cleanly and do not change files
run: |
make -f Makefile.sync clean checkout apply-patches sync
git diff --compact-summary --exit-code
make -f Makefile.sync clean sync
git diff --compact-summary --exit-code

View File

@ -19,8 +19,8 @@ linters:
- nolintlint
- nosprintfhostport
- staticcheck
- tenv
- unconvert
- usetesting
- wastedassign
- whitespace
disable:

View File

@ -24,7 +24,6 @@ set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)
set(GGML_CUDA_COMPRESSION_MODE default)
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))

View File

@ -21,16 +21,14 @@
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86"
}
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120"
}
},
{

View File

@ -51,7 +51,7 @@ see if the change were accepted.
The title should look like:
<package>: <short description>
<package>: <short description>
The package is the most affected Go package. If the change does not affect Go
code, then use the directory name instead. Changes to a single well-known

View File

@ -104,8 +104,8 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6
FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm

View File

@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618
FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c
.PHONY: help
help:
@ -15,30 +15,27 @@ help:
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
.PHONY: sync
sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml apply-patches
llama/build-info.cpp: llama/build-info.cpp.in llama/llama.cpp
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' <$< >$@
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
go generate ./$(@D)
.PHONY: llama/build-info.cpp
llama/build-info.cpp: llama/build-info.cpp.in
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
.PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/
llama/llama.cpp: llama/vendor/ apply-patches
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
.PHONY: ml/backend/ggml/ggml
ml/backend/ggml/ggml: llama/vendor/ggml/
.PHONY: ml/backend/ggml/ggml apply-patches
ml/backend/ggml/ggml: llama/vendor/ggml/ apply-patches
rsync -arvzc -f "merge $@/.rsync-filter" $< $@
PATCHES=$(wildcard llama/patches/*.patch)
PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))
.PHONY: apply-patches
.NOTPARALLEL:
apply-patches: $(PATCHED)
apply-patches: $(addsuffix ed, $(PATCHES))
llama/patches/.%.patched: llama/patches/%.patch
%.patched: %.patch
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
.PHONY: checkout
@ -60,4 +57,4 @@ format-patches: llama/patches
.PHONE: clean
clean: checkout
$(RM) llama/patches/.*.patched
$(RM) $(addsuffix ed, $(PATCHES))

View File

@ -61,8 +61,6 @@ Here are some example models that can be downloaded:
| QwQ | 32B | 20GB | `ollama run qwq` |
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
| Llama 4 | 109B | 67GB | `ollama run llama4:scout` |
| Llama 4 | 400B | 245GB | `ollama run llama4:maverick` |
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
| Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` |
| Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` |
@ -79,7 +77,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@ -287,13 +285,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI)
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
@ -314,8 +312,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
- [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira Integration to generate issues, tasks, epics)
- [ojira](https://github.com/AliAhmedNada/ojira) (Jira chrome plugin to easily generate descriptions for tasks)
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
@ -329,14 +325,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support, and multiple large language models.)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in Discord)
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord )
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy-to-use GUI with sample custom LLM for Drivers Education)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
@ -345,16 +341,16 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
- [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
- [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita)
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
- [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
- [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
@ -372,7 +368,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard, and said in the meetings)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
- [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
- [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
@ -390,7 +386,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
@ -398,13 +394,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
### Cloud
@ -446,8 +439,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
### Apple Vision Pro
@ -474,7 +466,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Libraries
- [LangChain](https://python.langchain.com/docs/integrations/chat/ollama/) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
- [crewAI](https://github.com/crewAIInc/crewAI)
- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)
@ -521,21 +513,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
### Mobile
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS, and iPad)
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
@ -559,7 +550,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
@ -569,8 +560,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
- [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depend on ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.)
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
@ -584,7 +575,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
### Supported backends

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
@ -136,7 +137,7 @@ func TestClientStream(t *testing.T) {
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse
err := client.stream(t.Context(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err)
@ -222,7 +223,7 @@ func TestClientDo(t *testing.T) {
ID string `json:"id"`
Success bool `json:"success"`
}
err := client.do(t.Context(), http.MethodPost, "/v1/messages", nil, &resp)
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" {
if err == nil {

View File

@ -76,7 +76,7 @@ type GenerateRequest struct {
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Images is an optional list of raw image bytes accompanying this
// Images is an optional list of base64-encoded images accompanying this
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
@ -163,65 +163,19 @@ func (t *ToolCallFunctionArguments) String() string {
type Tool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
Function ToolFunction `json:"function"`
}
// PropertyType can be either a string or an array of strings
type PropertyType []string
// UnmarshalJSON implements the json.Unmarshaler interface
func (pt *PropertyType) UnmarshalJSON(data []byte) error {
// Try to unmarshal as a string first
var s string
if err := json.Unmarshal(data, &s); err == nil {
*pt = []string{s}
return nil
}
// If that fails, try to unmarshal as an array of strings
var a []string
if err := json.Unmarshal(data, &a); err != nil {
return err
}
*pt = a
return nil
}
// MarshalJSON implements the json.Marshaler interface
func (pt PropertyType) MarshalJSON() ([]byte, error) {
if len(pt) == 1 {
// If there's only one type, marshal as a string
return json.Marshal(pt[0])
}
// Otherwise marshal as an array
return json.Marshal([]string(pt))
}
// String returns a string representation of the PropertyType
func (pt PropertyType) String() string {
if len(pt) == 0 {
return ""
}
if len(pt) == 1 {
return pt[0]
}
return fmt.Sprintf("%v", []string(pt))
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}
@ -271,6 +225,9 @@ type Options struct {
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
Mirostat int `json:"mirostat,omitempty"`
MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"`
Stop []string `json:"stop,omitempty"`
}
@ -280,7 +237,12 @@ type Runner struct {
NumBatch int `json:"num_batch,omitempty"`
NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"`
F16KV bool `json:"f16_kv,omitempty"` // Deprecated: This option is ignored
LogitsAll bool `json:"logits_all,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap *bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}
@ -463,6 +425,13 @@ type ProcessModelResponse struct {
SizeVRAM int64 `json:"size_vram"`
}
type RetrieveModelResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
type TokenResponse struct {
Token string `json:"token"`
}
@ -645,6 +614,9 @@ func DefaultOptions() Options {
RepeatPenalty: 1.1,
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
Mirostat: 0,
MirostatTau: 5.0,
MirostatEta: 0.1,
Seed: -1,
Runner: Runner{
@ -653,6 +625,8 @@ func DefaultOptions() Options {
NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumThread: 0, // let the runtime decide
LowVRAM: false,
UseMLock: false,
UseMMap: nil,
},
}

View File

@ -231,144 +231,3 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
}
}
}
func TestToolFunction_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
wantErr string
}{
{
name: "valid enum with same types",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": ["a", "b", "c"]
}
}
}
}`,
wantErr: "",
},
{
name: "empty enum array",
input: `{
"name": "test",
"description": "test function",
"parameters": {
"type": "object",
"required": ["test"],
"properties": {
"test": {
"type": "string",
"description": "test prop",
"enum": []
}
}
}
}`,
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var tf ToolFunction
err := json.Unmarshal([]byte(tt.input), &tf)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}
func TestPropertyType_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
expected PropertyType
}{
{
name: "string type",
input: `"string"`,
expected: PropertyType{"string"},
},
{
name: "array of types",
input: `["string", "number"]`,
expected: PropertyType{"string", "number"},
},
{
name: "array with single type",
input: `["string"]`,
expected: PropertyType{"string"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var pt PropertyType
if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(pt) != len(test.expected) {
t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
}
for i, v := range pt {
if v != test.expected[i] {
t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
}
}
})
}
}
func TestPropertyType_MarshalJSON(t *testing.T) {
tests := []struct {
name string
input PropertyType
expected string
}{
{
name: "single type",
input: PropertyType{"string"},
expected: `"string"`,
},
{
name: "multiple types",
input: PropertyType{"string", "number"},
expected: `["string","number"]`,
},
{
name: "empty type",
input: PropertyType{},
expected: `[]`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
data, err := json.Marshal(test.input)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if string(data) != test.expected {
t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
}
})
}
}

View File

@ -4,14 +4,20 @@ import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
)
func InitLogging() {
level := slog.LevelInfo
if envconfig.Debug() {
level = slog.LevelDebug
}
var logFile *os.File
var err error
// Detect if we're a GUI app on windows, and if not, send logs to console
@ -27,8 +33,20 @@ func InitLogging() {
return
}
}
handler := slog.NewTextHandler(logFile, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
slog.SetDefault(logutil.NewLogger(logFile, envconfig.LogLevel()))
slog.Info("ollama app started")
}

View File

@ -78,7 +78,7 @@ func BenchmarkColdStart(b *testing.B) {
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := b.Context()
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
@ -113,7 +113,7 @@ func BenchmarkWarmStart(b *testing.B) {
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := b.Context()
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
@ -140,7 +140,7 @@ func setup(b *testing.B) *api.Client {
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}

View File

@ -31,7 +31,6 @@ import (
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"golang.org/x/term"
"github.com/ollama/ollama/api"
@ -42,7 +41,6 @@ import (
"github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
)
@ -108,7 +106,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
spinner.Stop()
req.Model = args[0]
req.Name = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
@ -119,54 +117,34 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
files := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Files {
g.Go(func() error {
if len(req.Files) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Files {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
// TODO: this is incorrect since the file might be in a subdirectory
// instead this should take the path relative to the model directory
// but the current implementation does not allow this
files.Store(filepath.Base(f), digest)
return nil
})
fileMap[filepath.Base(f)] = digest
}
req.Files = fileMap
}
adapters := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Adapters {
g.Go(func() error {
if len(req.Adapters) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Adapters {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
// TODO: same here
adapters.Store(filepath.Base(f), digest)
return nil
})
fileMap[filepath.Base(f)] = digest
}
req.Adapters = fileMap
}
if err := g.Wait(); err != nil {
return err
}
req.Files = files.Items()
req.Adapters = adapters.Items()
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
bar, ok := bars[resp.Digest]
if !ok {
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@ -235,7 +213,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
}
}()
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
@ -830,38 +808,13 @@ func PullHandler(cmd *cobra.Command, args []string) error {
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
// This is the initial status update for the
// layer, which the server sends before
// beginning the download, for clients to
// compute total size and prepare for
// downloads, if needed.
//
// Skipping this here to avoid showing a 0%
// progress bar, which *should* clue the user
// into the fact that many things are being
// downloaded and that the current active
// download is not that last. However, in rare
// cases it seems to be triggering to some, and
// it isn't worth explaining, so just ignore
// and regress to the old UI that keeps giving
// you the "But wait, there is more!" after
// each "100% done" bar, which is "better."
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@ -881,7 +834,11 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
return client.Pull(cmd.Context(), &request, fn)
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
}
return nil
}
type generateContextKey string
@ -1424,6 +1381,7 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_SCHED_SPREAD"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_KV_CACHE_TYPE"],
envVars["OLLAMA_LLM_LIBRARY"],

View File

@ -2,6 +2,7 @@ package cmd
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -336,7 +337,7 @@ func TestDeleteHandler(t *testing.T) {
t.Cleanup(mockServer.Close)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.SetContext(context.TODO())
if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
t.Fatalf("DeleteHandler failed: %v", err)
}
@ -398,6 +399,11 @@ func TestGetModelfileName(t *testing.T) {
var expectedFilename string
if tt.fileExists {
tempDir, err := os.MkdirTemp("", "modelfiledir")
defer os.RemoveAll(tempDir)
if err != nil {
t.Fatalf("temp modelfile dir creation failed: %v", err)
}
var fn string
if tt.modelfileName != "" {
fn = tt.modelfileName
@ -405,11 +411,10 @@ func TestGetModelfileName(t *testing.T) {
fn = "Modelfile"
}
tempFile, err := os.CreateTemp(t.TempDir(), fn)
tempFile, err := os.CreateTemp(tempDir, fn)
if err != nil {
t.Fatalf("temp modelfile creation failed: %v", err)
}
defer tempFile.Close()
expectedFilename = tempFile.Name()
err = cmd.Flags().Set("file", expectedFilename)
@ -524,7 +529,7 @@ func TestPushHandler(t *testing.T) {
cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context())
cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output
oldStderr := os.Stderr
@ -629,7 +634,7 @@ func TestListHandler(t *testing.T) {
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.SetContext(t.Context())
cmd.SetContext(context.TODO())
// Capture stdout
oldStdout := os.Stdout
@ -684,7 +689,7 @@ func TestCreateHandler(t *testing.T) {
return
}
if req.Model != "test-model" {
if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
}
@ -724,7 +729,7 @@ func TestCreateHandler(t *testing.T) {
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
tempFile, err := os.CreateTemp("", "modelfile")
if err != nil {
t.Fatal(err)
}
@ -744,7 +749,7 @@ func TestCreateHandler(t *testing.T) {
}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context())
cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output
oldStderr := os.Stderr

View File

@ -44,7 +44,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
}
fmt.Fprintln(os.Stderr, "")
@ -503,7 +503,6 @@ func normalizeFilePath(fp string) string {
"\\\\", "\\", // Escaped backslash
"\\*", "*", // Escaped asterisk
"\\?", "?", // Escaped question mark
"\\~", "~", // Escaped tilde
).Replace(fp)
}
@ -511,7 +510,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)
@ -531,8 +530,6 @@ func extractFileData(input string) (string, []api.ImageData, error) {
return "", imgs, err
}
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data)
}
@ -553,7 +550,7 @@ func getImageData(filePath string) ([]byte, error) {
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
}

View File

@ -1,8 +1,6 @@
package cmd
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
@ -12,17 +10,14 @@ func TestExtractFilenames(t *testing.T) {
// Unix style paths
input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG
/unescaped space /six.webp inbetween6 /valid\ path/dir/seven.WEBP`
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
res := extractFileNames(input)
assert.Len(t, res, 7)
assert.Len(t, res, 5)
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.webp")
assert.Contains(t, res[6], "seven.WEBP")
assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg")
@ -33,12 +28,10 @@ func TestExtractFilenames(t *testing.T) {
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG
c:/users/jdoe/eleven.webp inbetween11 c:/program files/someplace/twelve.WebP inbetween12
d:\path with\spaces\thirteen.WEBP some ending
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
`
res = extractFileNames(input)
assert.Len(t, res, 13)
assert.Len(t, res, 10)
assert.NotContains(t, res, "inbetween2")
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:")
@ -56,31 +49,4 @@ d:\path with\spaces\thirteen.WEBP some ending
assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:")
assert.Contains(t, res[10], "eleven.webp")
assert.Contains(t, res[10], "c:")
assert.Contains(t, res[11], "twelve.WebP")
assert.Contains(t, res[11], "c:")
assert.Contains(t, res[12], "thirteen.WEBP")
assert.Contains(t, res[12], "d:")
}
// Ensure that file paths wrapped in single quotes are removed with the quotes.
func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
dir := t.TempDir()
fp := filepath.Join(dir, "img.jpg")
data := make([]byte, 600)
copy(data, []byte{
0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F',
0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xff, 0xd9,
})
if err := os.WriteFile(fp, data, 0o600); err != nil {
t.Fatalf("failed to write test image: %v", err)
}
input := "before '" + fp + "' after"
cleaned, imgs, err := extractFileData(input)
assert.NoError(t, err)
assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after")
}

View File

@ -1,26 +1,25 @@
package convert
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"os"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
TextModel TextParameters `json:"text_config"`
}
TextModel struct {
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
type TextParameters struct {
VocabSize uint32 `json:"vocab_size"`
}
type AdapterParameters struct {
@ -85,17 +84,27 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
func (ModelParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(ws, kv, ts)
}
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(ws, kv, ts)
}
type ModelConverter interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
Tensors([]Tensor) []ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
// specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
}
type moreParser interface {
@ -106,13 +115,15 @@ type AdapterConverter interface {
// KV maps parameters to LLM key-values
KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor
Tensors([]Tensor) []ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@ -147,14 +158,14 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return err
}
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return err
@ -173,10 +184,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
switch p.Architectures[0] {
case "LlamaForCausalLM":
conv = &llamaModel{}
case "MllamaForConditionalGeneration":
conv = &mllamaModel{}
case "Llama4ForConditionalGeneration":
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "MixtralForCausalLM":
@ -191,8 +198,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &phi3Model{}
case "Qwen2ForCausalLM":
conv = &qwen2Model{}
case "Qwen2_5_VLForConditionalGeneration":
conv = &qwen25VLModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":
@ -216,22 +221,24 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return err
}
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
vocabSize := int(p.VocabSize)
if vocabSize == 0 {
tVocabSize := int(p.TextModel.VocabSize)
vocabSize = tVocabSize
}
switch {
case vocabSize == 0:
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
case vocabSize > len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
}
case vocabSize < len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens))
p.VocabSize = uint32(len(t.Vocabulary.Tokens))
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
@ -241,13 +248,5 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return err
}
return writeFile(f, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)
}
return ggml.WriteGGUF(f, kv, ts)
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
}

View File

@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",
@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
continue
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -43,10 +43,10 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *commandrModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne)
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -21,8 +21,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
return kv
}
func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack)
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -28,12 +28,12 @@ type llamaModel struct {
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling struct {
Type string `json:"type"`
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
Type string `json:"type"`
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
factors ropeFactor
} `json:"rope_scaling"`
@ -42,8 +42,6 @@ type llamaModel struct {
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
HeadDim uint32 `json:"head_dim"`
skipRepack bool
}
var _ ModelConverter = (*llamaModel)(nil)
@ -72,10 +70,6 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
}
if p.HeadDim > 0 {
kv["llama.attention.head_dim"] = p.HeadDim
}
if p.RopeTheta > 0 {
kv["llama.rope.freq_base"] = p.RopeTheta
}
@ -90,7 +84,7 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionEmbeddings, 8192)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
lambdaLow := float32(original) / factorLow
lambdaHigh := float32(original) / factorHigh
@ -126,11 +120,11 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
if p.RopeScaling.factors != nil {
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: "rope_freqs.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
@ -139,13 +133,12 @@ func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
for _, t := range ts {
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
if !p.skipRepack {
t.SetRepacker(p.repack)
}
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(p.repack)
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -1,169 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type llama4Model struct {
ModelParameters
TextModel struct {
llamaModel
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NumLocalExperts uint32 `json:"num_local_experts"`
InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
UseQKNorm bool `json:"use_qk_norm"`
IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"`
AttentionChunkSize uint32 `json:"attention_chunk_size"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
NormEpsilon float32 `json:"norm_eps"`
PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
} `json:"vision_config"`
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"
for k, v := range p.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
}
}
kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize
kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
return kv
}
// Replacements implements ModelConverter.
func (p *llama4Model) Replacements() []string {
return append(
p.TextModel.Replacements(),
"language_model.", "",
"vision_model", "v",
"multi_modal_projector", "mm",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.", "ffn_",
"shared_expert.down_proj", "down_shexp",
"shared_expert.gate_proj", "gate_shexp",
"shared_expert.up_proj", "up_shexp",
"experts.down_proj", "down_exps.weight",
"experts.gate_up_proj", "gate_up_exps.weight",
"router", "gate_inp",
"patch_embedding.linear", "patch_embedding",
)
}
// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
var textTensors []Tensor
for _, t := range ts {
if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
// gate and up projectors are fused
// dims[1], dims[2] must be swapped
// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
halfDim := int(t.Shape()[2]) / 2
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2]/2, newShape[1]
for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
// clone tensor since we need separate repackers
tt := t.Clone()
tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
out = append(out, &ggml.Tensor{
Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
Kind: tt.Kind(),
Shape: newShape,
WriterTo: tt,
})
}
} else if strings.Contains(t.Name(), "ffn_down_exps") {
// dims[1], dims[2] must be swapped
// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
t.SetRepacker(p.repack())
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2], newShape[1]
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
} else {
textTensors = append(textTensors, t)
}
}
p.TextModel.skipRepack = true
out = append(out, p.TextModel.Tensors(textTensors)...)
return out
}
func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
return func(name string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
if err := t.T(0, 2, 1); err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be return as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
return kv
}
func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack)
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: shape,

View File

@ -89,8 +89,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
@ -100,7 +100,7 @@ func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
}
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -29,7 +29,7 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
oldnew := []string{
"model.layers", "blk",
"w1", "ffn_gate_exps",
@ -56,10 +56,10 @@ func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
return true
})
var out []*ggml.Tensor
var out []ggml.Tensor
for n, e := range experts {
// TODO(mxyng): sanity check experts
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: n,
Kind: e[0].Kind(),
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),

View File

@ -1,160 +0,0 @@
package convert
import (
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
type mllamaModel struct {
ModelParameters
TextModel struct {
llamaModel
CrossAttentionLayers []int32 `json:"cross_attention_layers"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumGlobalLayers uint32 `json:"num_global_layers"`
IntermediateLayersIndices []int32 `json:"intermediate_layers_indices"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
AttentionHeads uint32 `json:"attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
NumChannels uint32 `json:"num_channels"`
MaxNumTiles uint32 `json:"max_num_tiles"`
NormEpsilon float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope.freq_base"`
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"
for k, v := range m.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "mllama.")] = v
}
}
kv["mllama.attention.cross_attention_layers"] = m.TextModel.CrossAttentionLayers
kv["mllama.vision.block_count"] = m.VisionModel.NumHiddenLayers
kv["mllama.vision.global.block_count"] = m.VisionModel.NumGlobalLayers
kv["mllama.vision.intermediate_layers_indices"] = m.VisionModel.IntermediateLayersIndices
kv["mllama.vision.embedding_length"] = m.VisionModel.HiddenSize
kv["mllama.vision.feed_forward_length"] = m.VisionModel.IntermediateSize
kv["mllama.vision.attention.head_count"] = m.VisionModel.AttentionHeads
kv["mllama.vision.attention.layer_norm_epsilon"] = m.VisionModel.NormEpsilon
kv["mllama.vision.image_size"] = m.VisionModel.ImageSize
kv["mllama.vision.patch_size"] = m.VisionModel.PatchSize
kv["mllama.vision.max_num_tiles"] = m.VisionModel.MaxNumTiles
kv["mllama.vision.num_channels"] = m.VisionModel.NumChannels
return kv
}
func (m *mllamaModel) Replacements() []string {
return append(
m.TextModel.Replacements(),
"language_model.", "",
"gate_attn", "attn_gate",
"gate_ffn", "ffn_gate",
"cross_attn.", "cross_attn_",
"vision_model", "v",
"class_embedding", "class_embd",
"patch_embedding", "patch_embd",
"gated_positional_embedding.tile_embedding", "tile_position_embd",
"gated_positional_embedding.embedding", "position_embd.weight",
"gated_positional_embedding", "position_embd",
"embedding.weight", "weight",
"pre_tile_positional_embedding", "pre_tile_position_embd",
"post_tile_positional_embedding", "post_tile_position_embd",
"layernorm_pre", "pre_ln",
"layernorm_post", "post_ln",
"global_transformer.layers", "global.blk",
"transformer.layers", "blk",
"mlp.fc1", "ffn_up",
"mlp.fc2", "ffn_down",
"multi_modal_projector", "mm.0",
)
}
func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
var text []Tensor
for _, t := range ts {
if t.Name() == "v.position_embd.gate" {
for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
tt := t.Clone()
tt.SetRepacker(m.repack(name))
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: tt,
})
}
} else if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
t.SetRepacker(m.repack(t.Name()))
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else {
text = append(text, t)
}
}
return append(out, m.TextModel.Tensors(text)...)
}
func (m *mllamaModel) repack(name string) Repacker {
return func(_ string, data []float32, shape []uint64) (_ []float32, err error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err = tensor.Tanh(t)
if err != nil {
return nil, err
}
if name == "v.position_embd.gate" {
t, err = tensor.Sub(float32(1), t)
if err != nil {
return nil, err
}
}
t = tensor.Materialize(t)
// flatten tensor so it can be return as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
return kv
}
func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor {
func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
var addRopeFactors sync.Once
out := make([]*ggml.Tensor, 0, len(ts)+2)
out := make([]ggml.Tensor, 0, len(ts)+2)
for _, t := range ts {
if strings.HasPrefix(t.Name(), "blk.0.") {
addRopeFactors.Do(func() {
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: "rope_factors_long.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
WriterTo: p.RopeScaling.LongFactor,
}, &ggml.Tensor{
}, ggml.Tensor{
Name: "rope_factors_short.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor {
})
}
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
@ -118,5 +118,6 @@ func (p *phi3Model) Replacements() []string {
type ropeFactor []float32
func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
return 0, binary.Write(w, binary.LittleEndian, r)
err := binary.Write(w, binary.LittleEndian, r)
return 0, err
}

View File

@ -15,7 +15,6 @@ type qwen2Model struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
MropeSection []int32 `json:"mrope_section"`
} `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"`
}
@ -40,18 +39,16 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
case "yarn":
kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type
kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor
case "mrope", "default":
kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection
default:
panic("unknown rope scaling type")
}
return kv
}
func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
func (q *qwen2Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
out = append(out, &ggml.Tensor{
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),

View File

@ -1,102 +0,0 @@
package convert
import (
"cmp"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type qwen25VLModel struct {
qwen2Model
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_chans"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
FullAttentionBlocks []int32 `json:"fullatt_block_indexes"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
}
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"
for k, v := range q.qwen2Model.KV(t) {
if strings.HasPrefix(k, "qwen2.") {
kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v
}
}
if q.VisionModel.FullAttentionBlocks == nil {
kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31}
}
kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32)
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16)
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14)
kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2)
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112)
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4)
kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2)
return kv
}
func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if strings.Contains(t.Name(), "patch_embed.proj") {
for t := range splitDim(t, 2,
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
) {
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
out = append(out, t)
}
} else if strings.Contains(t.Name(), "attn.qkv") {
out = append(out, slices.Collect(splitDim(t, 0,
strings.NewReplacer("attn.qkv", "attn_q"),
strings.NewReplacer("attn.qkv", "attn_k"),
strings.NewReplacer("attn.qkv", "attn_v"),
))...)
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
func (p *qwen25VLModel) Replacements() []string {
return append(
p.qwen2Model.Replacements(),
"visual", "v",
"blocks", "blk",
"attn.proj", "attn_out",
"norm1", "ln1",
"norm2", "ln2",
)
}

View File

@ -11,6 +11,7 @@ import (
"io"
"io/fs"
"log/slog"
"math"
"os"
"path/filepath"
"slices"
@ -47,7 +48,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
}
t.Cleanup(func() { r.Close() })
m, _, err := ggml.Decode(r, -1)
m, _, err := ggml.Decode(r, math.MaxInt)
if err != nil {
t.Fatal(err)
}
@ -130,7 +131,6 @@ func TestConvertModel(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer expectFile.Close()
var expect map[string]string
if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
}
defer r.Close()
m, _, err := ggml.Decode(r, -1)
m, _, err := ggml.Decode(r, math.MaxInt)
if err != nil {
t.Fatal(err)
}

58
convert/fs.go Normal file
View File

@ -0,0 +1,58 @@
package convert
import (
"archive/zip"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
)
type ZipReader struct {
r *zip.Reader
p string
// limit is the maximum size of a file that can be read directly
// from the zip archive. Files larger than this size will be extracted
limit int64
}
func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
return &ZipReader{r, p, limit}
}
func (z *ZipReader) Open(name string) (fs.File, error) {
r, err := z.r.Open(name)
if err != nil {
return nil, err
}
defer r.Close()
if fi, err := r.Stat(); err != nil {
return nil, err
} else if fi.Size() < z.limit {
return r, nil
}
if !filepath.IsLocal(name) {
return nil, zip.ErrInsecurePath
}
n := filepath.Join(z.p, name)
if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
w, err := os.Create(n)
if err != nil {
return nil, err
}
defer w.Close()
if _, err := io.Copy(w, r); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
return os.Open(n)
}

View File

@ -11,15 +11,14 @@ type Tensor interface {
Name() string
Shape() []uint64
Kind() uint32
SetRepacker(Repacker)
SetRepacker(repacker)
WriteTo(io.Writer) (int64, error)
Clone() Tensor
}
type tensorBase struct {
name string
shape []uint64
repacker Repacker
name string
shape []uint64
repacker
}
func (t tensorBase) Name() string {
@ -37,11 +36,7 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" {
t.name == "token_types.weight" {
// these tensors are always F32
return 0
}
@ -56,11 +51,11 @@ func (t tensorBase) Kind() uint32 {
}
}
func (t *tensorBase) SetRepacker(fn Repacker) {
func (t *tensorBase) SetRepacker(fn repacker) {
t.repacker = fn
}
type Repacker func(string, []float32, []uint64) ([]float32, error)
type repacker func(string, []float32, []uint64) ([]float32, error)
func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
patterns := []struct {

View File

@ -94,21 +94,6 @@ type safetensor struct {
*tensorBase
}
func (st safetensor) Clone() Tensor {
return &safetensor{
fs: st.fs,
path: st.path,
dtype: st.dtype,
offset: st.offset,
size: st.size,
tensorBase: &tensorBase{
name: st.name,
repacker: st.repacker,
shape: slices.Clone(st.shape),
},
}
}
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
f, err := st.fs.Open(st.path)
if err != nil {

View File

@ -43,17 +43,6 @@ type torch struct {
*tensorBase
}
func (t torch) Clone() Tensor {
return torch{
storage: t.storage,
tensorBase: &tensorBase{
name: t.name,
shape: t.shape,
repacker: t.repacker,
},
}
}
func (pt torch) WriteTo(w io.Writer) (int64, error) {
return 0, nil
}

View File

@ -1,56 +0,0 @@
package convert
import (
"iter"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
// is split evenly based on the number of replacers provided.
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
return func(yield func(*ggml.Tensor) bool) {
for i, replacer := range replacers {
shape := slices.Clone(t.Shape())
shape[dim] = shape[dim] / uint64(len(replacers))
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))
tt := t.Clone()
tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be written as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
})
if !yield(&ggml.Tensor{
Name: replacer.Replace(t.Name()),
Kind: t.Kind(),
Shape: shape,
WriterTo: tt,
}) {
break
}
}
}
}

View File

@ -670,7 +670,7 @@ func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, e
}
func getVerboseState() C.uint16_t {
if envconfig.LogLevel() < slog.LevelInfo {
if envconfig.Debug() {
return C.uint16_t(1)
}
return C.uint16_t(0)

View File

@ -27,14 +27,12 @@
#endif
#ifndef LOG
#define LOG(verbose, ...) \
do { \
if (verbose) { \
fprintf(stderr, __VA_ARGS__); \
} \
} while (0)
#endif
#ifdef __cplusplus
extern "C" {

View File

@ -1,7 +1,6 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include <inttypes.h>
#include "gpu_info_cudart.h"
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
@ -59,7 +58,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) {
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return;
}
@ -169,9 +168,9 @@ void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
resp->free = memInfo.free;
resp->used = memInfo.used;
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total);
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free);
LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used);
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
LOG(h.verbose, "[%s] CUDA usedMem %lu\n", resp->gpu_id, resp->used);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
}
@ -181,4 +180,4 @@ void cudart_release(cudart_handle_t h) {
h.handle = NULL;
}
#endif // __APPLE__
#endif // __APPLE__

View File

@ -1,7 +1,6 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include <inttypes.h>
#include "gpu_info_nvcuda.h"
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
@ -194,8 +193,8 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
resp->total = memInfo.total;
resp->free = memInfo.free;
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024);
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024);
LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
@ -248,4 +247,4 @@ void nvcuda_release(nvcuda_handle_t h) {
h.handle = NULL;
}
#endif // __APPLE__
#endif // __APPLE__

View File

@ -19,7 +19,7 @@
### Model names
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q8_0` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
### Durations
@ -173,7 +173,7 @@ curl http://localhost:11434/api/generate -d '{
##### Response
```json5
```json
{
"model": "codellama:code",
"created_at": "2024-07-22T20:47:51.147561Z",
@ -394,6 +394,9 @@ curl http://localhost:11434/api/generate -d '{
"repeat_penalty": 1.2,
"presence_penalty": 1.5,
"frequency_penalty": 1.0,
"mirostat": 1,
"mirostat_tau": 0.8,
"mirostat_eta": 0.6,
"penalize_newline": true,
"stop": ["\n", "user:"],
"numa": false,
@ -401,7 +404,10 @@ curl http://localhost:11434/api/generate -d '{
"num_batch": 2,
"num_gpu": 1,
"main_gpu": 0,
"low_vram": false,
"vocab_only": false,
"use_mmap": true,
"use_mlock": false,
"num_thread": 8
}
}'
@ -952,8 +958,19 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
| Type | Recommended |
| --- | :-: |
| q2_K | |
| q3_K_L | |
| q3_K_M | |
| q3_K_S | |
| q4_0 | |
| q4_1 | |
| q4_K_M | * |
| q4_K_S | |
| q5_0 | |
| q5_1 | |
| q5_K_M | |
| q5_K_S | |
| q6_K | |
| q8_0 | * |
### Examples
@ -998,8 +1015,8 @@ Quantize a non-quantized model.
```shell
curl http://localhost:11434/api/create -d '{
"model": "llama3.2:quantized",
"from": "llama3.2:3b-instruct-fp16",
"model": "llama3.1:quantized",
"from": "llama3.1:8b-instruct-fp16",
"quantize": "q4_K_M"
}'
```
@ -1009,14 +1026,12 @@ curl http://localhost:11434/api/create -d '{
A stream of JSON objects is returned:
```json
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":12302}
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":6433687552}
{"status":"verifying conversion"}
{"status":"creating new layer sha256:fb7f4f211b89c6c4928ff4ddb73db9f9c0cfca3e000c3e40d6cf27ddc6ca72eb"}
{"status":"using existing layer sha256:966de95ca8a62200913e3f8bfbf84c8494536f1b94b49166851e76644e966396"}
{"status":"using existing layer sha256:fcc5a6bec9daf9b561a68827b67ab6088e1dba9d1fa2a50d7bbcc8384e0a265d"}
{"status":"using existing layer sha256:a70ff7e570d97baaf4e62ac6e6ad9975e04caa6d900d3742d37698494479e0cd"}
{"status":"quantizing F16 model to Q4_K_M"}
{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
{"status":"using existing layer sha256:0ba8f0e314b4264dfd19df045cde9d4c394a52474bf92ed6a3de22a4ca31a177"}
{"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}
{"status":"creating new layer sha256:455f34728c9b5dd3376378bfb809ee166c145b0b4c1f1a6feca069055066ef9a"}
{"status":"writing manifest"}
{"status":"success"}
```
@ -1154,37 +1169,29 @@ A single JSON object will be returned.
{
"models": [
{
"name": "deepseek-r1:latest",
"model": "deepseek-r1:latest",
"modified_at": "2025-05-10T08:06:48.639712648-07:00",
"size": 4683075271,
"digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163",
"name": "codellama:13b",
"modified_at": "2023-11-04T14:56:49.277302595-07:00",
"size": 7365960935,
"digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
"details": {
"parent_model": "",
"format": "gguf",
"family": "qwen2",
"families": [
"qwen2"
],
"parameter_size": "7.6B",
"quantization_level": "Q4_K_M"
"family": "llama",
"families": null,
"parameter_size": "13B",
"quantization_level": "Q4_0"
}
},
{
"name": "llama3.2:latest",
"model": "llama3.2:latest",
"modified_at": "2025-05-04T17:37:44.706015396-07:00",
"size": 2019393189,
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
"name": "llama3:latest",
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
"size": 3825819519,
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
"details": {
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": [
"llama"
],
"parameter_size": "3.2B",
"quantization_level": "Q4_K_M"
"families": null,
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
@ -1216,7 +1223,7 @@ curl http://localhost:11434/api/show -d '{
#### Response
```json5
```json
{
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
"parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"",

View File

@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?
By default, Ollama uses a context window size of 4096 tokens.
By default, Ollama uses a context window size of 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

View File

@ -150,6 +150,9 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |

View File

@ -12,7 +12,7 @@ A basic Go template consists of three main parts:
Here's an example of a simple chat template:
```go
```gotmpl
{{- range .Messages }}
{{ .Role }}: {{ .Content }}
{{- end }}
@ -162,6 +162,6 @@ CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://o
Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.
```go
```gotmpl
[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
```

View File

@ -26,6 +26,7 @@ When you run Ollama on **Windows**, there are a few different locations. You can
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories
To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
@ -68,6 +69,10 @@ If you run into problems on Linux and want to install an older version, or you'd
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
```
## Linux tmp noexec
If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
## Linux docker
If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.

View File

@ -62,6 +62,7 @@ the explorer window by hitting `<Ctrl>+R` and type in:
- *upgrade.log* contains log output for upgrades
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
## Uninstall

View File

@ -149,22 +149,9 @@ func Bool(k string) func() bool {
}
}
// LogLevel returns the log level for the application.
// Values are 0 or false INFO (Default), 1 or true DEBUG, 2 TRACE
func LogLevel() slog.Level {
level := slog.LevelInfo
if s := Var("OLLAMA_DEBUG"); s != "" {
if b, _ := strconv.ParseBool(s); b {
level = slog.LevelDebug
} else if i, _ := strconv.ParseInt(s, 10, 64); i != 0 {
level = slog.Level(i * -4)
}
}
return level
}
var (
// Debug enabled additional debug information.
Debug = Bool("OLLAMA_DEBUG")
// FlashAttention enables the experimental flash attention feature.
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
// KvCacheType is the quantization type for the K/V cache.
@ -182,7 +169,7 @@ var (
// Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 2048)
)
func String(s string) func() string {
@ -222,6 +209,8 @@ var (
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
)
func Uint64(key string, defaultValue uint64) func() uint64 {
@ -249,7 +238,7 @@ type EnvVar struct {
func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
@ -266,7 +255,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 2048)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
// Informational

View File

@ -1,13 +1,11 @@
package envconfig
import (
"log/slog"
"math"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/logutil"
)
func TestHost(t *testing.T) {
@ -281,8 +279,8 @@ func TestVar(t *testing.T) {
func TestContextLength(t *testing.T) {
cases := map[string]uint{
"": 4096,
"2048": 2048,
"": 2048,
"4096": 4096,
}
for k, v := range cases {
@ -294,34 +292,3 @@ func TestContextLength(t *testing.T) {
})
}
}
func TestLogLevel(t *testing.T) {
cases := map[string]slog.Level{
// Default to INFO
"": slog.LevelInfo,
"false": slog.LevelInfo,
"f": slog.LevelInfo,
"0": slog.LevelInfo,
// True values enable Debug
"true": slog.LevelDebug,
"t": slog.LevelDebug,
// Positive values increase verbosity
"1": slog.LevelDebug,
"2": logutil.LevelTrace,
// Negative values decrease verbosity
"-1": slog.LevelWarn,
"-2": slog.LevelError,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", k)
if i := LogLevel(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
}
})
}
}

View File

@ -8,6 +8,6 @@ type Config interface {
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Ints(string, ...[]int32) []int32
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"io"
"log/slog"
"math"
"slices"
"strings"
@ -34,15 +33,15 @@ func (kv KV) Kind() string {
}
func (kv KV) ParameterCount() uint64 {
return keyValue(kv, "general.parameter_count", uint64(0))
return keyValue[uint64](kv, "general.parameter_count")
}
func (kv KV) FileType() FileType {
func (kv KV) FileType() fileType {
if t := kv.Uint("general.file_type"); t > 0 {
return FileType(t)
return fileType(t)
}
return FileTypeUnknown
return fileTypeUnknown
}
func (kv KV) BlockCount() uint64 {
@ -106,44 +105,42 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
}
r := keyValue(kv, key, &array{})
s := make([]string, r.size)
for i := range r.size {
s[i] = r.values[i].(string)
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
return s
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
r := keyValue(kv, key, &array{})
s := make([]uint32, r.size)
for i := range r.size {
s[i] = uint32(r.values[i].(int32))
}
return s
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
r := keyValue(kv, key, &array{})
s := make([]float32, r.size)
for i := range r.size {
s[i] = float32(r.values[i].(float32))
}
return s
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"gemma3",
"mistral3",
"llama4",
"mllama",
"qwen25vl",
}, kv.Architecture())
}
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
*array[string] | *array[float32] | *array[float64] | *array[bool]
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
@ -152,7 +149,7 @@ func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ..
return val.(T)
}
slog.Debug("key not found", "key", key, "default", defaultValue[0])
slog.Warn("key not found", "key", key, "default", defaultValue[0])
return defaultValue[0]
}
@ -229,11 +226,7 @@ func (t Tensor) block() (n int) {
}
func (t Tensor) blockSize() uint64 {
return (TensorType)(t.Kind).BlockSize()
}
func (t TensorType) BlockSize() uint64 {
switch t {
switch t.Kind {
case
0, // F32
1, // F16
@ -259,77 +252,73 @@ func (t TensorType) BlockSize() uint64 {
}
func (t Tensor) typeSize() uint64 {
return TensorType(t.Kind).TypeSize()
}
blockSize := t.blockSize()
func (t TensorType) TypeSize() uint64 {
blockSize := t.BlockSize()
switch t {
case TensorTypeF32:
switch t.Kind {
case 0: // FP32
return 4
case TensorTypeF16:
case 1: // FP16
return 2
case TensorTypeQ4_0:
case 2: // Q4_0
return 2 + blockSize/2
case TensorTypeQ4_1:
case 3: // Q4_1
return 2 + 2 + blockSize/2
case TensorTypeQ5_0:
case 6: // Q5_0
return 2 + 4 + blockSize/2
case TensorTypeQ5_1:
case 7: // Q5_1
return 2 + 2 + 4 + blockSize/2
case TensorTypeQ8_0:
case 8: // Q8_0
return 2 + blockSize
case TensorTypeQ8_1:
case 9: // Q8_1
return 2 + 2 + blockSize
case TensorTypeQ2_K:
case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2
case TensorTypeQ3_K:
case 11: // Q3_K
return blockSize/8 + blockSize/4 + 12 + 2
case TensorTypeQ4_K:
case 12: // Q4_K
return 2 + 2 + 12 + blockSize/2
case TensorTypeQ5_K:
case 13: // Q5_K
return 2 + 2 + 12 + blockSize/8 + blockSize/2
case TensorTypeQ6_K:
case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2
case TensorTypeQ8_K:
case 15: // Q8_K
return 4 + blockSize + 2*blockSize/16
case tensorTypeIQ2_XXS:
case 16: // IQ2_XXS
return 2 + 2*blockSize/8
case tensorTypeIQ2_XS:
case 17: // IQ2_XS
return 2 + 2*blockSize/8 + blockSize/32
case tensorTypeIQ3_XXS:
case 18: // IQ3_XXS
return 2 + blockSize/4 + blockSize/8
case tensorTypeIQ1_S:
case 19: // IQ1_S
return 2 + blockSize/8 + blockSize/16
case tensorTypeIQ4_NL:
case 20: // IQ4_NL
return 2 + blockSize/2
case tensorTypeIQ3_S:
case 21: // IQ3_S
return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
case tensorTypeIQ2_S:
case 22: // IQ2_S
return 2 + blockSize/4 + blockSize/16
case tensorTypeIQ4_XS:
case 23: // IQ4_XS
return 2 + 2 + blockSize/2 + blockSize/64
case TensorTypeI8:
case 24: // I8
return 1
case TensorTypeI16:
case 25: // I16
return 2
case TensorTypeI32:
case 26: // I32
return 4
case TensorTypeI64:
case 27: // I64
return 8
case TensorTypeF64:
case 28: // F64
return 8
case tensorTypeIQ1_M:
case 29: // IQ1_M
return blockSize/8 + blockSize/16 + blockSize/32
case TensorTypeBF16:
case 30: // BF16
return 2
default:
return 0
}
}
func (t Tensor) Elements() uint64 {
func (t Tensor) parameters() uint64 {
var count uint64 = 1
for _, n := range t.Shape {
count *= n
@ -338,11 +327,11 @@ func (t Tensor) Elements() uint64 {
}
func (t Tensor) Size() uint64 {
return t.Elements() * t.typeSize() / t.blockSize()
return t.parameters() * t.typeSize() / t.blockSize()
}
func (t Tensor) Type() string {
return TensorType(t.Kind).String()
return fileType(t.Kind).String()
}
type container interface {
@ -386,8 +375,13 @@ func DetectContentType(b []byte) string {
// Decode decodes a GGML model from the given reader.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If the maxArraySize is negative, all arrays are collected.
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
if maxArraySize == 0 {
maxArraySize = 1024
}
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32
@ -426,7 +420,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
embeddingHeads := f.KV().EmbeddingHeadCount()
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
@ -441,7 +435,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
}
switch f.KV().Architecture() {
case "llama", "llama4":
case "llama":
fullOffload = max(
4*batch*(1+4*embedding+context*(1+heads)),
4*batch*(embedding+vocab),
@ -455,7 +449,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
// mixtral 8x22b
ff := uint64(f.KV().Uint("feed_forward_length"))
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
partialOffload = max(
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@ -472,9 +466,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
case "mllama":
var visionTokens, tiles uint64 = 1601, 4
crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers")
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
for i := range kv {
if slices.Contains(crossAttentionLayers, int32(i)) {
if slices.Contains(crossAttentionLayers, uint32(i)) {
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
4 * // sizeof(float32)
visionTokens *
@ -491,7 +485,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
var ropeFreqsCount uint64
if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
ropeFreqsCount = ropeFreqsWeights.Elements()
ropeFreqsCount = ropeFreqsWeights.parameters()
}
}
@ -651,32 +645,6 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
case "qwen25vl":
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
mergeSize := uint64(llm.KV().Uint("vision.spatial_merge_size", 2))
temporalPatchSize := uint64(2)
// Calculate max possible patches based on max_pixels
maxHeight := uint64(math.Sqrt(float64(maxPixels)))
maxWidth := maxPixels / maxHeight
maxGridHeight := maxHeight / patchSize
maxGridWidth := maxWidth / patchSize
// Account for merged patches (2x2 grid)
numPatches := (maxGridHeight * maxGridWidth) / (mergeSize * mergeSize)
// Calculate graph size based on typical operations in ProcessImage and createPatches
graphSize = 4 * (maxPixels*numChannels + // Original image storage
// Normalized pixels
maxPixels*numChannels +
// Patches storage (numPatches * channels * temporalPatchSize * patchSize^2)
numPatches*numChannels*temporalPatchSize*patchSize*patchSize +
// Self-attention calculations (similar to other architectures)
numPatches*numPatches*headCount +
// Additional buffer for processing
embeddingLength*numPatches)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph
}
return weights, graphSize

View File

@ -2,7 +2,6 @@ package ggml
import (
"maps"
"math"
"slices"
"strconv"
"strings"
@ -211,61 +210,3 @@ func TestTensorTypes(t *testing.T) {
})
}
}
func TestKeyValue(t *testing.T) {
kv := KV{
"general.architecture": "test",
"test.strings": &array[string]{size: 3, values: []string{"a", "b", "c"}},
"test.float32s": &array[float32]{size: 3, values: []float32{1.0, 2.0, 3.0}},
"test.int32s": &array[int32]{size: 3, values: []int32{1, 2, 3}},
"test.uint32s": &array[uint32]{size: 3, values: []uint32{1, 2, 3}},
}
if diff := cmp.Diff(kv.Strings("strings"), []string{"a", "b", "c"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("nonexistent.strings"), []string(nil)); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("default.strings", []string{"ollama"}), []string{"ollama"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("float32s"), []float32{1.0, 2.0, 3.0}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("nonexistent.float32s"), []float32(nil)); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("default.float32s", []float32{math.MaxFloat32}), []float32{math.MaxFloat32}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("int32s"), []int32{1, 2, 3}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("nonexistent.int32s"), []int32(nil)); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("default.int32s", []int32{math.MaxInt32}), []int32{math.MaxInt32}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("uint32s"), []uint32{1, 2, 3}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("nonexistent.uint32s"), []uint32(nil)); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("default.uint32s", []uint32{math.MaxUint32}), []uint32{math.MaxUint32}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
}

View File

@ -9,12 +9,8 @@ import (
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"golang.org/x/sync/errgroup"
)
type containerGGUF struct {
@ -40,6 +36,10 @@ type containerGGUF struct {
maxArraySize int
}
func (c *containerGGUF) canCollectArray(size int) bool {
return c.maxArraySize < 0 || size <= c.maxArraySize
}
func (c *containerGGUF) Name() string {
return "gguf"
}
@ -229,13 +229,16 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
}
llm.tensors = append(llm.tensors, &tensor)
llm.parameters += tensor.Elements()
llm.parameters += tensor.parameters()
}
// patch KV with parameter count
llm.kv["general.parameter_count"] = llm.parameters
alignment := llm.kv.Uint("general.alignment", 32)
alignment, ok := llm.kv["general.alignment"].(uint32)
if !ok {
alignment = 32
}
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
@ -295,23 +298,6 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
return b.String(), nil
}
func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
for i := range a.size {
if a.values != nil {
e, err := readGGUFV1String(llm, r)
if err != nil {
return nil, err
}
a.values[i] = e
} else {
discardGGUFString(llm, r)
}
}
return a, nil
}
func discardGGUFString(llm *gguf, r io.Reader) error {
buf := llm.scratch[:8]
_, err := io.ReadFull(r, buf)
@ -369,44 +355,78 @@ func writeGGUFString(w io.Writer, s string) error {
return err
}
func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
for i := range a.size {
if a.values != nil {
e, err := readGGUFString(llm, r)
if err != nil {
return nil, err
}
type array struct {
size int
values []any
}
func (a *array) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
n, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, 0, int(n))
}
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
e, err = readGGUFV1String(llm, r)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
if err != nil {
return nil, err
}
if a.values != nil {
a.values[i] = e
} else {
discardGGUFString(llm, r)
}
}
return a, nil
}
type array[T any] struct {
// size is the actual size of the array
size int
// values is the array of values. this is nil if the array is larger than configured maxSize
values []T
}
func (a *array[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func newArray[T any](size, maxSize int) *array[T] {
a := array[T]{size: size}
if maxSize < 0 || size <= maxSize {
a.values = make([]T, size)
func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
if llm.Version == 1 {
return readGGUFV1Array(llm, r)
}
return &a
}
func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
@ -417,55 +437,45 @@ func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
return nil, err
}
switch t {
case ggufTypeUint8:
a := newArray[uint8](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt8:
a := newArray[int8](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint16:
a := newArray[uint16](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt16:
a := newArray[int16](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint32:
a := newArray[uint32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt32:
a := newArray[int32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint64:
a := newArray[uint64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt64:
a := newArray[int64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeFloat32:
a := newArray[float32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeFloat64:
a := newArray[float64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeBool:
a := newArray[bool](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeString:
a := newArray[string](int(n), llm.maxArraySize)
if llm.Version == 1 {
return readGGUFV1StringsData(llm, r, a)
}
return readGGUFStringsData(llm, r, a)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, int(n))
}
}
func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) {
for i := range a.size {
e, err := readGGUF[T](llm, r)
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
if a.values != nil {
e, err = readGGUFString(llm, r)
} else {
err = discardGGUFString(llm, r)
}
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
if err != nil {
return nil, err
}
@ -492,38 +502,23 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return err
}
if t == ggufTypeString {
for _, e := range any(s).([]string) {
if err := binary.Write(w, binary.LittleEndian, uint64(len(e))); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, []byte(e)); err != nil {
return err
}
}
return nil
}
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
alignment := kv.Uint("general.alignment", 32)
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(ts))); err != nil {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
@ -531,12 +526,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(f, key, kv[key]); err != nil {
if err := ggufWriteKV(ws, key, kv[key]); err != nil {
return err
}
}
slices.SortStableFunc(ts, func(a, b *Tensor) int {
slices.SortStableFunc(ts, func(a, b Tensor) int {
if i, j := a.block(), b.block(); i < 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
@ -547,34 +542,22 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
})
var s uint64
for i := range ts {
ts[i].Offset = s
if err := ggufWriteTensorInfo(f, ts[i]); err != nil {
for _, t := range ts {
t.Offset = s
if err := ggufWriteTensorInfo(ws, t); err != nil {
return err
}
s += ts[i].Size()
s += uint64(ggufPadding(int64(s), int64(alignment)))
s += t.Size()
}
offset, err := f.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
offset += ggufPadding(offset, int64(alignment))
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
var alignment int64 = 32
for _, t := range ts {
t := t
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
g.Go(func() error {
_, err := t.WriteTo(w)
if err := ggufWriteTensor(ws, t, alignment); err != nil {
return err
})
}
}
return g.Wait()
return nil
}
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
@ -589,10 +572,8 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
var err error
switch v := v.(type) {
case uint32, FileType:
case uint32:
err = writeGGUF(ws, ggufTypeUint32, v)
case uint64:
err = writeGGUF(ws, ggufTypeUint64, v)
case float32:
err = writeGGUF(ws, ggufTypeFloat32, v)
case bool:
@ -601,20 +582,32 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
err = writeGGUFString(ws, v)
case []int32:
err = writeGGUFArray(ws, ggufTypeInt32, v)
case *array[int32]:
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
case []uint32:
err = writeGGUFArray(ws, ggufTypeUint32, v)
case *array[uint32]:
err = writeGGUFArray(ws, ggufTypeUint32, v.values)
case []float32:
err = writeGGUFArray(ws, ggufTypeFloat32, v)
case *array[float32]:
err = writeGGUFArray(ws, ggufTypeFloat32, v.values)
case []string:
err = writeGGUFArray(ws, ggufTypeString, v)
case *array[string]:
err = writeGGUFArray(ws, ggufTypeString, v.values)
if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
return err
}
for _, e := range v {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
return err
}
}
default:
return fmt.Errorf("improper type for '%s'", k)
}
@ -622,7 +615,7 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
return err
}
func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
return err
@ -636,8 +629,8 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
return err
}
for _, n := range t.Shape {
if err := binary.Write(ws, binary.LittleEndian, n); err != nil {
for i := range len(t.Shape) {
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
return err
}
}
@ -649,6 +642,20 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
return binary.Write(ws, binary.LittleEndian, t.Offset)
}
func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
offset, err := ws.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
return err
}
_, err = t.WriteTo(ws)
return err
}
func ggufPadding(offset, align int64) int64 {
return (align - offset%align) % align
}

View File

@ -1,63 +0,0 @@
package ggml
import (
"bytes"
"os"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWriteGGUF(t *testing.T) {
w, err := os.CreateTemp(t.TempDir(), "*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, []*Tensor{
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
}); err != nil {
t.Fatal(err)
}
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
ff, _, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(ff.KV(), KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(36),
}); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(ff.Tensors(), Tensors{
Offset: 336,
items: []*Tensor{
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
},
}, cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
}

View File

@ -1,31 +1,26 @@
package ggml
import (
"fmt"
"log/slog"
"strings"
)
import "fmt"
// FileType is the Go equivalent to llama_ftype used for gguf file typing
type FileType uint32
type fileType uint32
const (
FileTypeF32 FileType = iota
FileTypeF16
fileTypeF32 fileType = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16 // unused by GGML
fileTypeQ4_2 // unused by GGML
fileTypeQ4_3 // unused by GGML
FileTypeQ8_0
fileTypeQ4_1_F16
fileTypeQ4_2 // unused
fileTypeQ4_3 // unused
fileTypeQ8_0
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
FileTypeQ4_K_S
FileTypeQ4_K_M
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
@ -42,62 +37,93 @@ const (
fileTypeIQ2_M
fileTypeIQ4_XS
fileTypeIQ1_M
FileTypeBF16
fileTypeQ4_0_4_4 // unused by GGML
fileTypeQ4_0_4_8 // unused by GGML
fileTypeQ4_0_8_8 // unused by GGML
fileTypeTQ1_0
fileTypeTQ2_0
fileTypeBF16
FileTypeUnknown = 1024
fileTypeUnknown
)
// ParseFileType parses the provided GGUF file type
// Only Ollama supported types are considered valid
func ParseFileType(s string) (FileType, error) {
func ParseFileType(s string) (fileType, error) {
switch s {
case "F32":
return FileTypeF32, nil
return fileTypeF32, nil
case "F16":
return FileTypeF16, nil
return fileTypeF16, nil
case "Q4_0":
return fileTypeQ4_0, nil
case "Q4_1":
return fileTypeQ4_1, nil
case "Q4_1_F16":
return fileTypeQ4_1_F16, nil
case "Q8_0":
return FileTypeQ8_0, nil
return fileTypeQ8_0, nil
case "Q5_0":
return fileTypeQ5_0, nil
case "Q5_1":
return fileTypeQ5_1, nil
case "Q2_K":
return fileTypeQ2_K, nil
case "Q3_K_S":
return fileTypeQ3_K_S, nil
case "Q3_K_M":
return fileTypeQ3_K_M, nil
case "Q3_K_L":
return fileTypeQ3_K_L, nil
case "Q4_K_S":
return FileTypeQ4_K_S, nil
case "Q4_K_M", "Q4_K":
return FileTypeQ4_K_M, nil
return fileTypeQ4_K_S, nil
case "Q4_K_M":
return fileTypeQ4_K_M, nil
case "Q5_K_S":
return fileTypeQ5_K_S, nil
case "Q5_K_M":
return fileTypeQ5_K_M, nil
case "Q6_K":
return fileTypeQ6_K, nil
case "IQ2_XXS":
return fileTypeIQ2_XXS, nil
case "IQ2_XS":
return fileTypeIQ2_XS, nil
case "Q2_K_S":
return fileTypeQ2_K_S, nil
case "IQ3_XS":
return fileTypeIQ3_XS, nil
case "IQ3_XXS":
return fileTypeIQ3_XXS, nil
case "IQ1_S":
return fileTypeIQ1_S, nil
case "IQ4_NL":
return fileTypeIQ4_NL, nil
case "IQ3_S":
return fileTypeIQ3_S, nil
case "IQ3_M":
return fileTypeIQ3_M, nil
case "IQ2_S":
return fileTypeIQ2_S, nil
case "IQ2_M":
return fileTypeIQ2_M, nil
case "IQ4_XS":
return fileTypeIQ4_XS, nil
case "IQ1_M":
return fileTypeIQ1_M, nil
case "BF16":
return FileTypeBF16, nil
return fileTypeBF16, nil
default:
supportedFileTypes := []FileType{
FileTypeF32,
FileTypeF16,
FileTypeQ4_K_S,
FileTypeQ4_K_M,
FileTypeQ8_0,
// fsggml.FileTypeBF16, // TODO
}
strs := make([]string, len(supportedFileTypes))
for i := range supportedFileTypes {
strs[i] = supportedFileTypes[i].String()
}
return FileTypeUnknown, fmt.Errorf("unsupported quantization type %s - supported types are %s", s, strings.Join(strs, ", "))
return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
}
}
func (t FileType) String() string {
// Note: this routine will return a broader set of file types for existing models
func (t fileType) String() string {
switch t {
case FileTypeF32:
case fileTypeF32:
return "F32"
case FileTypeF16:
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case FileTypeQ8_0:
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
@ -111,9 +137,9 @@ func (t FileType) String() string {
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case FileTypeQ4_K_S:
case fileTypeQ4_K_S:
return "Q4_K_S"
case FileTypeQ4_K_M:
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
@ -121,198 +147,39 @@ func (t FileType) String() string {
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S:
return "Q2_K_S"
case FileTypeBF16:
case fileTypeIQ3_XS:
return "IQ3_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
case fileTypeIQ1_S:
return "IQ1_S"
case fileTypeIQ4_NL:
return "IQ4_NL"
case fileTypeIQ3_S:
return "IQ3_S"
case fileTypeIQ3_M:
return "IQ3_M"
case fileTypeIQ2_S:
return "IQ2_S"
case fileTypeIQ4_XS:
return "IQ4_XS"
case fileTypeIQ2_M:
return "IQ2_M"
case fileTypeIQ1_M:
return "IQ1_M"
case fileTypeBF16:
return "BF16"
default:
return "unknown"
}
}
func (t FileType) Value() uint32 {
func (t fileType) Value() uint32 {
return uint32(t)
}
func (ftype FileType) ToTensorType() TensorType {
switch ftype {
case FileTypeF32:
return TensorTypeF32
case FileTypeF16:
return TensorTypeF16
case fileTypeQ4_0:
return TensorTypeQ4_0
case fileTypeQ4_1:
return TensorTypeQ4_1
case FileTypeQ8_0:
return TensorTypeQ8_0
case fileTypeQ5_0:
return TensorTypeQ5_0
case fileTypeQ5_1:
return TensorTypeQ5_1
case fileTypeQ2_K:
return TensorTypeQ2_K
case fileTypeQ3_K_S:
return TensorTypeQ3_K
case fileTypeQ3_K_M:
return TensorTypeQ3_K
case fileTypeQ3_K_L:
return TensorTypeQ3_K
case FileTypeQ4_K_S:
return TensorTypeQ4_K
case FileTypeQ4_K_M:
return TensorTypeQ4_K
case fileTypeQ5_K_S:
return TensorTypeQ5_K
case fileTypeQ5_K_M:
return TensorTypeQ5_K
case fileTypeQ6_K:
return TensorTypeQ6_K
case fileTypeQ2_K_S:
return TensorTypeQ2_K
case FileTypeBF16:
return TensorTypeBF16
default:
slog.Warn("unsupported file type", "type", ftype)
return 0 // F32
}
}
// TensorType is equivalent to ggml_type for individual tensor types
// Note: these are not the same as FileType
type TensorType uint32
const (
TensorTypeF32 TensorType = iota
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
tensorTypeQ4_2 // unused by GGML
tensorTypeQ4_3 // unused by GGML
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
TensorTypeQ8_1
TensorTypeQ2_K
TensorTypeQ3_K
TensorTypeQ4_K
TensorTypeQ5_K
TensorTypeQ6_K
TensorTypeQ8_K
tensorTypeIQ2_XXS // not supported by ollama
tensorTypeIQ2_XS // not supported by ollama
tensorTypeIQ3_XXS // not supported by ollama
tensorTypeIQ1_S // not supported by ollama
tensorTypeIQ4_NL // not supported by ollama
tensorTypeIQ3_S // not supported by ollama
tensorTypeIQ2_S // not supported by ollama
tensorTypeIQ4_XS // not supported by ollama
TensorTypeI8
TensorTypeI16
TensorTypeI32
TensorTypeI64
TensorTypeF64
tensorTypeIQ1_M // not supported by ollama
TensorTypeBF16
tensorTypeQ4_0_4_4 // unused by GGML
tensorTypeQ4_0_4_8 // unused by GGML
tensorTypeQ4_0_8_8 // unused by GGML
tensorTypeTQ1_0 // not supported by ollama
tensorTypeTQ2_0 // not supported by ollama
tensorTypeIQ4_NL_4_4 // unused by GGML
tensorTypeIQ4_NL_4_8 // unused by GGML
tensorTypeIQ4_NL_8_8 // unused by GGML
)
// ParseFileType parses the provided GGUF file type
// Only Ollama supported types are considered valid
func ParseTensorType(s string) (TensorType, error) {
switch s {
case "F32":
return TensorTypeF32, nil
case "F16":
return TensorTypeF16, nil
case "Q4_0":
return TensorTypeQ4_0, nil
case "Q4_1":
return TensorTypeQ4_1, nil
case "Q5_0":
return TensorTypeQ5_0, nil
case "Q5_1":
return TensorTypeQ5_1, nil
case "Q8_0":
return TensorTypeQ8_0, nil
case "Q8_1":
return TensorTypeQ8_1, nil
case "Q2_K":
return TensorTypeQ2_K, nil
case "Q3_K":
return TensorTypeQ3_K, nil
case "Q4_K":
return TensorTypeQ4_K, nil
case "Q5_K":
return TensorTypeQ5_K, nil
case "Q6_K":
return TensorTypeQ6_K, nil
case "Q8_K":
return TensorTypeQ8_K, nil
case "F64":
return TensorTypeF64, nil
case "BF16":
return TensorTypeBF16, nil
default:
return 0, fmt.Errorf("unsupported quantization type %s", s)
}
}
func (t TensorType) IsQuantized() bool {
switch t {
case TensorTypeF32, TensorTypeF16, TensorTypeBF16:
return false
default:
return true
}
}
func (t TensorType) RowSize(ne uint64) uint64 {
return t.TypeSize() * ne / t.BlockSize()
}
func (t TensorType) String() string {
switch t {
case TensorTypeF32:
return "F32"
case TensorTypeF16:
return "F16"
case TensorTypeQ4_0:
return "Q4_0"
case TensorTypeQ4_1:
return "Q4_1"
case TensorTypeQ5_0:
return "Q5_0"
case TensorTypeQ5_1:
return "Q5_1"
case TensorTypeQ8_0:
return "Q8_0"
case TensorTypeQ8_1:
return "Q8_1"
case TensorTypeQ2_K:
return "Q2_K"
case TensorTypeQ3_K:
return "Q3_K"
case TensorTypeQ4_K:
return "Q4_K"
case TensorTypeQ5_K:
return "Q5_K"
case TensorTypeQ6_K:
return "Q6_K"
case TensorTypeQ8_K:
return "Q8_K"
case TensorTypeF64:
return "F64"
case TensorTypeBF16:
return "BF16"
default:
return "unknown"
}
}

12
go.mod
View File

@ -11,7 +11,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.12.0
golang.org/x/sync v0.11.0
)
require (
@ -70,12 +70,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0
golang.org/x/crypto v0.33.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0
golang.org/x/term v0.30.0
golang.org/x/text v0.23.0
golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.30.0
golang.org/x/term v0.29.0
golang.org/x/text v0.22.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

24
go.sum
View File

@ -214,8 +214,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@ -1,412 +0,0 @@
//go:build integration
package integration
import (
"bytes"
"context"
"fmt"
"math/rand"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestAPIGenerate(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: "why is the sky blue? be brief",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("pull failed %s", err)
}
tests := []struct {
name string
stream bool
}{
{
name: "stream",
stream: true,
},
{
name: "no_stream",
stream: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.GenerateResponse) error {
// Fields that must always be present
if response.Model == "" {
t.Errorf("response missing model: %#v", response)
}
if response.Done {
// Required fields for final updates:
if response.DoneReason == "" && *req.Stream {
// TODO - is the lack of done reason on non-stream a bug?
t.Errorf("final response missing done_reason: %#v", response)
}
if response.Metrics.TotalDuration == 0 {
t.Errorf("final response missing total_duration: %#v", response)
}
if response.Metrics.LoadDuration == 0 {
t.Errorf("final response missing load_duration: %#v", response)
}
if response.Metrics.PromptEvalDuration == 0 {
t.Errorf("final response missing prompt_eval_duration: %#v", response)
}
if response.Metrics.EvalCount == 0 {
t.Errorf("final response missing eval_count: %#v", response)
}
if response.Metrics.EvalDuration == 0 {
t.Errorf("final response missing eval_duration: %#v", response)
}
if len(response.Context) == 0 {
t.Errorf("final response missing context: %#v", response)
}
// Note: caching can result in no prompt eval count, so this can't be verified reliably
// if response.Metrics.PromptEvalCount == 0 {
// t.Errorf("final response missing prompt_eval_count: %#v", response)
// }
} // else incremental response, nothing to check right now...
buf.Write([]byte(response.Response))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
done := make(chan int)
var genErr error
go func() {
req.Stream = &test.stream
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
genErr = client.Generate(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil {
t.Fatalf("failed with %s request prompt %s ", req.Model, req.Prompt)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
})
}
// Validate PS while we're at it...
resp, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("list models API error: %s", err)
}
if resp == nil || len(resp.Models) == 0 {
t.Fatalf("list models API returned empty list while model should still be loaded")
}
// Find the model we just loaded and verify some attributes
found := false
for _, model := range resp.Models {
if strings.Contains(model.Name, req.Model) {
found = true
if model.Model == "" {
t.Errorf("model field omitted: %#v", model)
}
if model.Size == 0 {
t.Errorf("size omitted: %#v", model)
}
if model.Digest == "" {
t.Errorf("digest omitted: %#v", model)
}
verifyModelDetails(t, model.Details)
var nilTime time.Time
if model.ExpiresAt == nilTime {
t.Errorf("expires_at omitted: %#v", model)
}
// SizeVRAM could be zero.
}
}
if !found {
t.Errorf("unable to locate running model: %#v", resp)
}
}
func TestAPIChat(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
// Set up the test data
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: "why is the sky blue? be brief",
},
},
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("pull failed %s", err)
}
tests := []struct {
name string
stream bool
}{
{
name: "stream",
stream: true,
},
{
name: "no_stream",
stream: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.ChatResponse) error {
// Fields that must always be present
if response.Model == "" {
t.Errorf("response missing model: %#v", response)
}
if response.Done {
// Required fields for final updates:
var nilTime time.Time
if response.CreatedAt == nilTime {
t.Errorf("final response missing total_duration: %#v", response)
}
if response.DoneReason == "" {
t.Errorf("final response missing done_reason: %#v", response)
}
if response.Metrics.TotalDuration == 0 {
t.Errorf("final response missing total_duration: %#v", response)
}
if response.Metrics.LoadDuration == 0 {
t.Errorf("final response missing load_duration: %#v", response)
}
if response.Metrics.PromptEvalDuration == 0 {
t.Errorf("final response missing prompt_eval_duration: %#v", response)
}
if response.Metrics.EvalCount == 0 {
t.Errorf("final response missing eval_count: %#v", response)
}
if response.Metrics.EvalDuration == 0 {
t.Errorf("final response missing eval_duration: %#v", response)
}
if response.Metrics.PromptEvalCount == 0 {
t.Errorf("final response missing prompt_eval_count: %#v", response)
}
} // else incremental response, nothing to check right now...
buf.Write([]byte(response.Message.Content))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
done := make(chan int)
var genErr error
go func() {
req.Stream = &test.stream
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("chat never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("chat stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil {
t.Fatalf("failed with %s request prompt %v", req.Model, req.Messages)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for chat")
}
})
}
}
func TestAPIListModels(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure we have at least one model so an empty list can be considered a failure
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("unable to list models: %s", err)
}
if len(resp.Models) == 0 {
t.Fatalf("list should not be empty")
}
model := resp.Models[0]
if model.Name == "" {
t.Errorf("first model name empty: %#v", model)
}
var nilTime time.Time
if model.ModifiedAt == nilTime {
t.Errorf("first model modified_at empty: %#v", model)
}
if model.Size == 0 {
t.Errorf("first model size empty: %#v", model)
}
if model.Digest == "" {
t.Errorf("first model digest empty: %#v", model)
}
verifyModelDetails(t, model.Details)
}
func verifyModelDetails(t *testing.T, details api.ModelDetails) {
if details.Format == "" {
t.Errorf("first model details.format empty: %#v", details)
}
if details.Family == "" {
t.Errorf("first model details.family empty: %#v", details)
}
if details.ParameterSize == "" {
t.Errorf("first model details.parameter_size empty: %#v", details)
}
if details.QuantizationLevel == "" {
t.Errorf("first model details.quantization_level empty: %#v", details)
}
}
func TestAPIShowModel(t *testing.T) {
modelName := "llama3.2"
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, modelName); err != nil {
t.Fatalf("pull failed %s", err)
}
resp, err := client.Show(ctx, &api.ShowRequest{Name: modelName})
if err != nil {
t.Fatalf("unable to show model: %s", err)
}
if resp.License == "" {
t.Errorf("%s missing license: %#v", modelName, resp)
}
if resp.Modelfile == "" {
t.Errorf("%s missing modelfile: %#v", modelName, resp)
}
if resp.Parameters == "" {
t.Errorf("%s missing parameters: %#v", modelName, resp)
}
if resp.Template == "" {
t.Errorf("%s missing template: %#v", modelName, resp)
}
// llama3 omits system
verifyModelDetails(t, resp.Details)
// llama3 ommits messages
if len(resp.ModelInfo) == 0 {
t.Errorf("%s missing model_info: %#v", modelName, resp)
}
// llama3 omits projectors
var nilTime time.Time
if resp.ModifiedAt == nilTime {
t.Errorf("%s missing modified_at: %#v", modelName, resp)
}
}
func TestAPIEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbeddingRequest{
Model: "orca-mini",
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("pull failed %s", err)
}
resp, err := client.Embeddings(ctx, &req)
if err != nil {
t.Fatalf("embeddings call failed %s", err)
}
if len(resp.Embedding) == 0 {
t.Errorf("zero length embedding response")
}
}

View File

@ -14,12 +14,12 @@ import (
"github.com/stretchr/testify/require"
)
func TestBlueSky(t *testing.T) {
func TestOrcaMiniBlueSky(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]any{
@ -31,7 +31,6 @@ func TestBlueSky(t *testing.T) {
}
func TestUnicode(t *testing.T) {
skipUnderMinVRAM(t, 6)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
// Set up the test data
@ -94,7 +93,7 @@ func TestUnicodeModelDir(t *testing.T) {
defer cancel()
req := api.GenerateRequest{
Model: smol,
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]any{

View File

@ -21,7 +21,7 @@ func TestMultiModelConcurrency(t *testing.T) {
var (
req = [2]api.GenerateRequest{
{
Model: "llama3.2:1b",
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
@ -67,7 +67,7 @@ func TestMultiModelConcurrency(t *testing.T) {
wg.Wait()
}
func TestIntegrationConcurrentPredict(t *testing.T) {
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
req, resp := GenerateRequests()
reqLimit := len(req)
iterLimit := 5
@ -117,9 +117,6 @@ func TestMultiModelStress(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if maxVram < 2*format.GibiByte {
t.Skip("VRAM less than 2G, skipping model stress tests")
}
type model struct {
name string
@ -128,8 +125,8 @@ func TestMultiModelStress(t *testing.T) {
smallModels := []model{
{
name: "llama3.2:1b",
size: 2876 * format.MebiByte,
name: "orca-mini",
size: 2992 * format.MebiByte,
},
{
name: "phi",

View File

@ -34,15 +34,13 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
}
res, err := embeddingTestHelper(ctx, client, t, req)
res, err := embeddingTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@ -64,15 +62,13 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, client, t, req)
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@ -102,15 +98,13 @@ func TestAllMiniLMEmbed(t *testing.T) {
func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"},
}
res, err := embedTestHelper(ctx, client, t, req)
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@ -150,8 +144,6 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
func TestAllMiniLMEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
truncTrue, truncFalse := true, false
@ -190,7 +182,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := embedTestHelper(ctx, client, t, req.Request)
response, err := embedTestHelper(ctx, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
@ -206,7 +198,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
@ -218,7 +210,9 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
}
}
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
@ -232,7 +226,9 @@ func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T,
return response, nil
}
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}

View File

@ -12,51 +12,58 @@ import (
"github.com/stretchr/testify/require"
)
func TestVisionModels(t *testing.T) {
skipUnderMinVRAM(t, 6)
type testCase struct {
model string
}
testCases := []testCase{
{
model: "llava:7b",
func TestIntegrationLlava(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: "llava:7b",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
{
model: "llama3.2-vision",
},
{
model: "gemma3",
Images: []api.ImageData{
image,
},
}
for _, v := range testCases {
t.Run(v.model, func(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: v.model,
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
})
func TestIntegrationMllama(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
// TODO fix up once we publish the final image
Model: "x/llama3.2-vision",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
resp := "the ollamas"
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// mllama models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
}
func TestIntegrationSplitBatch(t *testing.T) {

View File

@ -17,7 +17,7 @@ var (
stream = false
req = [2]api.GenerateRequest{
{
Model: smol,
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]any{
@ -25,7 +25,7 @@ var (
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]any{
@ -35,12 +35,12 @@ var (
},
}
resp = [2][]string{
{"sunlight", "scattering", "interact"},
{"sunlight"},
{"england", "english", "massachusetts", "pilgrims"},
}
)
func TestIntegrationSimple(t *testing.T) {
func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
GenerateTestHelper(ctx, t, req[0], resp[0])

View File

@ -30,7 +30,7 @@ func TestMaxQueue(t *testing.T) {
t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
req := api.GenerateRequest{
Model: smol,
Model: "orca-mini",
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
Options: map[string]any{
"seed": 42,
@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
embedCtx := ctx
var genwg sync.WaitGroup
genwg.Add(1)
go func() {
genwg.Add(1)
defer genwg.Done()
slog.Info("Starting generate request")
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
@ -61,7 +61,7 @@ func TestMaxQueue(t *testing.T) {
}()
// Give the generate a chance to get started before we start hammering on embed requests
time.Sleep(10 * time.Millisecond)
time.Sleep(5 * time.Millisecond)
threadCount += 10 // Add a few extra to ensure we push the queue past its limit
busyCount := 0
@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
counterMu := sync.Mutex{}
var embedwg sync.WaitGroup
for i := 0; i < threadCount; i++ {
embedwg.Add(1)
go func(i int) {
embedwg.Add(1)
defer embedwg.Done()
slog.Info("embed started", "id", i)
embedReq := api.EmbeddingRequest{

View File

@ -1,184 +0,0 @@
//go:build integration && models
package integration
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
)
var (
started = time.Now()
chatModels = []string{
"granite3-moe:latest",
"granite-code:latest",
"nemotron-mini:latest",
"command-r:latest",
"gemma2:latest",
"gemma:latest",
"internlm2:latest",
"phi3.5:latest",
"phi3:latest",
// "phi:latest", // flaky, sometimes generates no response on first query
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
"falcon2:latest",
"minicpm-v:latest",
"mistral:latest",
"orca-mini:latest",
"llama2:latest",
"llama3.1:latest",
"llama3.2:latest",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen:latest",
"solar-pro:latest",
}
)
func TestModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
// TODO - fiddle with context size
req := api.GenerateRequest{
Model: model,
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
})
}
}
func TestModelsEmbed(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
data, err := ioutil.ReadFile(filepath.Join("testdata", "embed.json"))
if err != nil {
t.Fatalf("failed to open test data file: %s", err)
}
testCase := map[string][]float64{}
err = json.Unmarshal(data, &testCase)
if err != nil {
t.Fatalf("failed to load test data: %s", err)
}
for model, expected := range testCase {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
req := api.EmbeddingRequest{
Model: model,
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
resp, err := client.Embeddings(ctx, &req)
if err != nil {
t.Fatalf("embeddings call failed %s", err)
}
if len(resp.Embedding) == 0 {
t.Errorf("zero length embedding response")
}
if len(expected) != len(resp.Embedding) {
expStr := make([]string, len(resp.Embedding))
for i, v := range resp.Embedding {
expStr[i] = fmt.Sprintf("%0.6f", v)
}
// When adding new models, use this output to populate the testdata/embed.json
fmt.Printf("expected\n%s\n", strings.Join(expStr, ", "))
t.Fatalf("expected %d, got %d", len(expected), len(resp.Embedding))
}
sim := cosineSimilarity(resp.Embedding, expected)
if sim < 0.99 {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], resp.Embedding[0:5], sim)
}
})
}
}

View File

@ -1,130 +0,0 @@
//go:build integration && models
package integration
import (
"bytes"
"context"
"fmt"
"log/slog"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestQuantization(t *testing.T) {
sourceModels := []string{
"qwen2.5:0.5b-instruct-fp16",
}
quantizations := []string{
"Q8_0",
"Q4_K_S",
"Q4_K_M",
"Q4_K",
}
softTimeout, hardTimeout := getTimeouts(t)
started := time.Now()
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, base := range sourceModels {
if err := PullIfMissing(ctx, client, base); err != nil {
t.Fatalf("pull failed %s", err)
}
for _, quant := range quantizations {
newName := fmt.Sprintf("%s__%s", base, quant)
t.Run(newName, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
req := &api.CreateRequest{
Model: newName,
Quantization: quant,
From: base,
}
fn := func(resp api.ProgressResponse) error {
// fmt.Print(".")
return nil
}
t.Logf("quantizing: %s -> %s", base, quant)
if err := client.Create(ctx, req, fn); err != nil {
t.Fatalf("create failed %s", err)
}
defer func() {
req := &api.DeleteRequest{
Model: newName,
}
t.Logf("deleting: %s -> %s", base, quant)
if err := client.Delete(ctx, req); err != nil {
t.Logf("failed to clean up %s: %s", req.Model, err)
}
}()
// Check metadata on the model
resp, err := client.Show(ctx, &api.ShowRequest{Name: newName})
if err != nil {
t.Fatalf("unable to show model: %s", err)
}
if !strings.Contains(resp.Details.QuantizationLevel, quant) {
t.Fatalf("unexpected quantization for %s:\ngot: %s", newName, resp.Details.QuantizationLevel)
}
stream := true
genReq := api.GenerateRequest{
Model: newName,
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 3 * time.Second},
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Stream: &stream,
}
t.Logf("verifying: %s -> %s", base, quant)
// Some smaller quantizations can cause models to have poor quality
// or get stuck in repetition loops, so we stop as soon as we have any matches
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
reqCtx, reqCancel := context.WithCancel(ctx)
atLeastOne := false
var buf bytes.Buffer
genfn := func(response api.GenerateResponse) error {
buf.Write([]byte(response.Response))
fullResp := strings.ToLower(buf.String())
for _, resp := range anyResp {
if strings.Contains(fullResp, resp) {
atLeastOne = true
t.Log(fullResp)
reqCancel()
break
}
}
return nil
}
done := make(chan int)
var genErr error
go func() {
genErr = client.Generate(reqCtx, &genReq, genfn)
done <- 0
}()
select {
case <-done:
if genErr != nil && !atLeastOne {
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
t.Logf("passed")
})
}
}
}

File diff suppressed because one or more lines are too long

View File

@ -24,14 +24,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/stretchr/testify/require"
)
const (
smol = "llama3.2:1b"
)
func Init() {
lifecycle.InitLogging()
}
@ -145,7 +140,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
showCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(20*time.Second),
time.Now().Add(10*time.Second),
fmt.Errorf("show for existing model %s took too long", modelName),
)
defer cancel()
@ -162,7 +157,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
}
slog.Info("model missing", "model", modelName)
stallDuration := 60 * time.Second // This includes checksum verification, which can take a while on larger models, and slower systems
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
stallTimer := time.NewTimer(stallDuration)
fn := func(resp api.ProgressResponse) error {
// fmt.Print(".")
@ -217,7 +212,6 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
defer fp.Close()
data, err := io.ReadAll(fp)
if err != nil {
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
@ -289,11 +283,11 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
}
// Generate a set of requests
// By default each request uses llama3.2 as the model
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
return []api.GenerateRequest{
{
Model: smol,
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
@ -302,7 +296,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "why is the color of dirt brown?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
@ -311,7 +305,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
@ -320,7 +314,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "what is the origin of independence day?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
@ -329,7 +323,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
"temperature": 0.0,
},
}, {
Model: smol,
Model: "orca-mini",
Prompt: "what is the composition of air?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
@ -347,26 +341,3 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
{"nitrogen", "oxygen", "carbon", "dioxide"},
}
}
func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err := strconv.ParseUint(s, 10, 64)
require.NoError(t, err)
// Don't hammer on small VRAM cards...
if maxVram < gb*format.GibiByte {
t.Skip("skipping with small VRAM to avoid timeouts")
}
}
}
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline()
if !hasDeadline {
return 8 * time.Minute, 10 * time.Minute
} else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
t.Skip("too little time")
return time.Duration(0), time.Duration(0)
}
return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
}

View File

@ -56,9 +56,8 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. reserve is to preallocate memory
// without actually storing data in the cache.
StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
// entry in positions and seqs.
StartForward(ctx ml.Context, batch input.Batch) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)

View File

@ -21,7 +21,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type Causal struct {
DType ml.DType
windowSize int32
chunkSize int32
opts CausalOptions
@ -98,17 +97,6 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}
}
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
return &Causal{
windowSize: math.MaxInt32,
chunkSize: chunkSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil {
var config ml.CacheConfig
@ -158,60 +146,51 @@ func (c *Causal) Close() {
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
if !reserve {
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
c.curLoc = 0
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
c.curMask, err = c.buildMask(ctx)
return err
@ -239,7 +218,7 @@ func (c *Causal) findStartLoc() (int, error) {
}
}
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
}
func (c *Causal) updateSlidingWindow() {
@ -312,7 +291,6 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}

View File

@ -5,6 +5,7 @@ import (
"slices"
"testing"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
@ -86,64 +87,6 @@ func TestSWA(t *testing.T) {
testCache(t, backend, cache, tests)
}
func TestChunkedAttention(t *testing.T) {
cache := NewChunkedAttentionCache(2, nil)
defer cache.Close()
var b testBackend
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
x := float32(math.Inf(-1))
testCache(
t, &b, cache,
[]testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{
0, x, x, x,
0, 0, x, x,
x, x, 0, x,
x, x, 0, 0,
},
},
{
name: "SecondBatch",
in: []float32{5, 6, 7},
inShape: []int{1, 1, 3},
seqs: []int{0, 0, 0},
pos: []int32{4, 5, 6},
expected: []float32{1, 2, 3, 4, 5, 6, 7},
expectedShape: []int{1, 1, 7},
expectedMask: []float32{
x, x, x, x, 0, x, x,
x, x, x, x, 0, 0, x,
x, x, x, x, x, x, 0,
},
},
{
name: "ThirdBatch",
in: []float32{8, 9},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{7, 8},
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
expectedShape: []int{1, 1, 9},
expectedMask: []float32{
x, x, x, x, x, x, 0, 0, x,
x, x, x, x, x, x, x, x, 0,
},
},
},
)
}
func TestSequences(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(nil)
@ -338,7 +281,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
if err != nil {
panic(err)
}
@ -351,16 +294,8 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) {
t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
}
if !slices.Equal(out.Shape(), test.expectedShape) {
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
}
if !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
}
})
}
@ -380,7 +315,7 @@ func TestCanResume(t *testing.T) {
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
}, false)
})
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
@ -407,7 +342,7 @@ func TestCanResume(t *testing.T) {
err = cache.StartForward(context, input.Batch{
Positions: []int32{4, 5},
Sequences: []int{0, 0},
}, false)
})
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
@ -437,8 +372,14 @@ func TestCanResume(t *testing.T) {
}
}
type testBackend struct {
ml.Backend
type testBackend struct{}
func (b *testBackend) Config() fs.Config {
panic("not implemented")
}
func (b *testBackend) Get(name string) ml.Tensor {
panic("not implemented")
}
func (b *testBackend) NewContext() ml.Context {
@ -449,10 +390,12 @@ func (b *testBackend) NewContextSize(int) ml.Context {
return &testContext{}
}
type testContext struct {
ml.Context
func (b *testBackend) SystemInfo() string {
return "not implemented"
}
type testContext struct{}
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
total := 0
@ -490,17 +433,6 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return out, nil
}
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
s := make([]float32, 0, int((stop-start)/step))
for i := start; i < stop; i += step {
s = append(s, i)
}
out, _ := c.FromFloatSlice(s, len(s))
out.(*testTensor).dtype = dtype
return out
}
func (c *testContext) Input() ml.Context { return c }
func (c *testContext) Layer(int) ml.Context { return c }
@ -508,8 +440,6 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() error { return nil }
func (c *testContext) MaxGraphNodes() int {
return 10
}
@ -517,8 +447,6 @@ func (c *testContext) MaxGraphNodes() int {
func (c *testContext) Close() {}
type testTensor struct {
ml.Tensor
dtype ml.DType
elementSize int
data []float32
@ -546,6 +474,10 @@ func (t *testTensor) DType() ml.DType {
return t.dtype
}
func (t *testTensor) Bytes() []byte {
panic("not implemented")
}
func (t *testTensor) Floats() []float32 {
out := make([]float32, len(t.data))
copy(out, t.data)
@ -570,6 +502,64 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return out
}
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
offset /= t.elementSize
@ -592,7 +582,43 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
return view
}
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { panic("not implemented") }
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
copy(t2.(*testTensor).data, t.data)
return nil
}
func (t *testTensor) Duplicate(ctx ml.Context) ml.Tensor { panic("not implemented") }

View File

@ -27,11 +27,6 @@ type EncoderCache struct {
// anything will be stored)
curPos int32
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// ** cache metadata **
// was something stored in the cache?
@ -88,14 +83,12 @@ func (c *EncoderCache) Close() {
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
// We work with the most recent image
if len(batch.Multimodal) > 0 {
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
}
c.curReserve = reserve
return nil
}
@ -112,10 +105,8 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
}
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
if !c.curReserve {
c.encoderPos = c.curPos
c.encoderCached = true
}
c.encoderPos = c.curPos
c.encoderCached = true
if c.config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3)

View File

@ -41,9 +41,9 @@ func (c *WrapperCache) Close() {
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, batch, reserve)
err := cache.StartForward(ctx, batch)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {

2
llama/build-info.cpp generated vendored
View File

@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618";
char const *LLAMA_COMMIT = "d7cfe1ffe0f435d0048a6058d529daf76e072d9c";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";

View File

@ -10,11 +10,10 @@ include common/stb_image.*
include include/
include include/llama.*
include include/llama-*.*
include tools/
include tools/mtmd/
include tools/mtmd/clip.*
include tools/mtmd/clip-impl.*
include tools/mtmd/llava.*
include examples/
include examples/llava/
include examples/llava/clip.*
include examples/llava/llava.*
include src/
include src/llama.*
include src/llama-*.*

View File

@ -7,6 +7,10 @@
#include "common.h"
#include "log.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include <algorithm>
@ -48,11 +52,47 @@
#include <sys/stat.h>
#include <unistd.h>
#endif
#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#include <future>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#if defined(LLAMA_USE_CURL)
#ifdef __linux__
#include <linux/limits.h>
#elif defined(_WIN32)
# if !defined(PATH_MAX)
# define PATH_MAX MAX_PATH
# endif
#else
#include <sys/syslimits.h>
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
//
// CURL utils
//
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
struct curl_slist_ptr {
struct curl_slist * ptr = nullptr;
~curl_slist_ptr() {
if (ptr) {
curl_slist_free_all(ptr);
}
}
};
#endif // LLAMA_USE_CURL
using json = nlohmann::ordered_json;
//
// CPU utils
//
@ -443,11 +483,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$0");
}
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) {
@ -830,7 +865,7 @@ std::string fs_get_cache_directory() {
if (getenv("LLAMA_CACHE")) {
cache_directory = std::getenv("LLAMA_CACHE");
} else {
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
#ifdef __linux__
if (std::getenv("XDG_CACHE_HOME")) {
cache_directory = std::getenv("XDG_CACHE_HOME");
} else {
@ -840,9 +875,7 @@ std::string fs_get_cache_directory() {
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
#elif defined(_WIN32)
cache_directory = std::getenv("LOCALAPPDATA");
#else
# error Unknown architecture
#endif
#endif // __linux__
cache_directory = ensure_trailing_slash(cache_directory);
cache_directory += "llama.cpp";
}
@ -863,14 +896,22 @@ std::string fs_get_cache_file(const std::string & filename) {
//
// Model utils
//
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
llama_model * model = nullptr;
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
} else if (!params.model_url.empty()) {
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
} else {
model = llama_model_load_from_file(params.model.c_str(), mparams);
}
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
return iparams;
}
@ -905,13 +946,13 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_model_free(model);
return iparams;
}
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
}
@ -988,8 +1029,6 @@ struct common_init_result common_init_from_params(common_params & params) {
if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
llama_set_warmup(lctx, true);
std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab);
llama_token eos = llama_vocab_eos(vocab);
@ -1017,10 +1056,9 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
}
llama_kv_self_clear(lctx);
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
llama_set_warmup(lctx, false);
}
iparams.model.reset(model);
@ -1029,19 +1067,6 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
const char * hf_endpoint_env = getenv("HF_ENDPOINT");
const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env;
std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) {
model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') model_endpoint += '/';
}
return model_endpoint;
}
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
llama_clear_adapter_lora(ctx);
for (auto & la : lora) {
@ -1057,18 +1082,15 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (!params.devices.empty()) {
mparams.devices = params.devices.data();
}
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
@ -1076,13 +1098,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.kv_overrides = params.kv_overrides.data();
}
if (params.tensor_buft_overrides.empty()) {
mparams.tensor_buft_overrides = NULL;
} else {
GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
}
return mparams;
}
@ -1096,6 +1111,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base;
@ -1113,7 +1129,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
if (params.reranking) {
cparams.embeddings = true;
@ -1142,6 +1157,451 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
return tpp;
}
#ifdef LLAMA_USE_CURL
#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
int remaining_attempts = max_attempts;
while (remaining_attempts > 0) {
LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
return true;
}
int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
remaining_attempts--;
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
}
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
return false;
}
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__);
return false;
}
bool force_download = false;
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
// Check if hf-token or bearer-token was specified
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer " + hf_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
}
#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata;
std::string etag;
std::string last_modified;
if (file_exists) {
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("url") && metadata.at("url").is_string()) {
auto previous_url = metadata.at("url").get<std::string>();
if (previous_url != url) {
LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
return false;
}
}
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
return false;
}
}
} else {
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
};
common_load_model_from_url_headers headers;
{
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
}
}
return n_items;
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code != 200) {
// HEAD not supported, we don't know if the file has changed
// force trigger downloading
force_download = true;
LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
}
}
bool should_download = !file_exists || force_download;
if (!should_download) {
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}
if (should_download) {
std::string path_temporary = path + ".downloadInProgress";
if (file_exists) {
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}
// Set the output file
struct FILE_deleter {
void operator()(FILE * f) const {
fclose(f);
}
};
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
// display download progress
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
}
std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
}
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}
// Causes file to be closed explicitly here before we rename it.
outfile.reset();
// Write the updated JSON metadata file.
metadata.update({
{"url", url},
{"etag", headers.etag},
{"lastModified", headers.last_modified}
});
std::ofstream(metadata_path) << metadata.dump(4);
LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
}
return true;
}
struct llama_model * common_load_model_from_url(
const std::string & model_url,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params) {
// Basic validation of the model_url
if (model_url.empty()) {
LOG_ERR("%s: invalid model_url\n", __func__);
return NULL;
}
if (!common_download_file(model_url, local_path, hf_token)) {
return NULL;
}
// check for additional GGUFs split to download
int n_split = 0;
{
struct gguf_init_params gguf_params = {
/*.no_alloc = */ true,
/*.ctx = */ NULL,
};
auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
if (!ctx_gguf) {
LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str());
return NULL;
}
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
if (key_n_split >= 0) {
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
}
gguf_free(ctx_gguf);
}
if (n_split > 1) {
char split_prefix[PATH_MAX] = {0};
char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
// Verify the first split file format
// and extract split URL and PATH prefixes
{
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
return NULL;
}
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
return NULL;
}
}
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
for (int idx = 1; idx < n_split; idx++) {
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
return common_download_file(split_url, split_path, hf_token);
}, idx));
}
// Wait for all downloads to complete
for (auto & f : futures_download) {
if (!f.get()) {
return NULL;
}
}
}
return llama_model_load_from_file(local_path.c_str(), params);
}
struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params) {
// construct hugging face model url:
//
// --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
//
// --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
//
std::string model_url = "https://huggingface.co/";
model_url += repo;
model_url += "/resolve/main/";
model_url += remote_path;
return common_load_model_from_url(model_url, local_path, hf_token, params);
}
/**
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
*
* Return pair of <repo, file> (with "repo" already having tag removed)
*
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
*/
std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
if (string_split<std::string>(hf_repo, '/').size() != 2) {
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
}
// fetch model info from Hugging Face Hub API
json model_info;
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
std::string res_str;
std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer " + hf_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
throw std::runtime_error("error: cannot make GET request to HF API");
}
long res_code;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
if (res_code == 200) {
model_info = json::parse(res_str);
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
}
// check response
if (!model_info.contains("ggufFile")) {
throw std::runtime_error("error: model does not have ggufFile");
}
json & gguf_file = model_info.at("ggufFile");
if (!gguf_file.contains("rfilename")) {
throw std::runtime_error("error: ggufFile does not have rfilename");
}
return std::make_pair(hf_repo, gguf_file.at("rfilename"));
}
#else
struct llama_model * common_load_model_from_url(
const std::string & /*model_url*/,
const std::string & /*local_path*/,
const std::string & /*hf_token*/,
const struct llama_model_params & /*params*/) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr;
}
struct llama_model * common_load_model_from_hf(
const std::string & /*repo*/,
const std::string & /*remote_path*/,
const std::string & /*local_path*/,
const std::string & /*hf_token*/,
const struct llama_model_params & /*params*/) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return nullptr;
}
std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return std::make_pair("", "");
}
#endif // LLAMA_USE_CURL
//
// Batch utils
//
@ -1566,19 +2026,3 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result;
}
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
const int64_t ne_datapoint = llama_n_ctx(ctx);
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
ggml_opt_dataset_t result = ggml_opt_dataset_init(
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
for (int64_t idata = 0; idata < ndata; ++idata) {
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
}
return result;
}

View File

@ -66,6 +66,7 @@ enum llama_example {
LLAMA_EXAMPLE_COMMON,
LLAMA_EXAMPLE_SPECULATIVE,
LLAMA_EXAMPLE_MAIN,
LLAMA_EXAMPLE_INFILL,
LLAMA_EXAMPLE_EMBEDDING,
LLAMA_EXAMPLE_PERPLEXITY,
LLAMA_EXAMPLE_RETRIEVAL,
@ -95,7 +96,6 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
};
// dimensionality reduction methods, used by cvector-generator
@ -110,17 +110,9 @@ enum common_conversation_mode {
COMMON_CONVERSATION_MODE_AUTO = 2,
};
enum common_grammar_trigger_type {
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
};
struct common_grammar_trigger {
common_grammar_trigger_type type;
std::string value;
llama_token token = LLAMA_TOKEN_NULL;
std::string word;
bool at_start;
};
// sampling parameters
@ -161,7 +153,6 @@ struct common_params_sampling {
std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P,
@ -172,7 +163,8 @@ struct common_params_sampling {
std::string grammar; // optional BNF-like grammar to constrain sampling
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
std::set<llama_token> preserved_tokens;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@ -181,13 +173,6 @@ struct common_params_sampling {
std::string print() const;
};
struct common_params_model {
std::string path = ""; // model local path // NOLINT
std::string url = ""; // model url to download // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
};
struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
@ -201,13 +186,19 @@ struct common_params_speculative {
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
struct common_params_model model;
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string model = ""; // draft model for speculative decoding // NOLINT
std::string model_url = ""; // model url to download // NOLINT
};
struct common_params_vocoder {
struct common_params_model model;
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string speaker_file = ""; // speaker file path // NOLINT
std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
};
@ -263,12 +254,13 @@ struct common_params {
struct common_params_speculative speculative;
struct common_params_vocoder vocoder;
struct common_params_model model;
std::string model = ""; // model path // NOLINT
std::string model_alias = ""; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string prompt = ""; // NOLINT
std::string system_prompt = ""; // NOLINT
std::string prompt_file = ""; // store the external prompt file name // NOLINT
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
@ -280,7 +272,6 @@ struct common_params {
std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides;
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
@ -324,6 +315,7 @@ struct common_params {
bool ctx_shift = true; // context shift on inifinite text generation
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation
@ -332,19 +324,14 @@ struct common_params {
bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool single_turn = false; // single turn chat conversation
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see tools/mtmd)
struct common_params_model mmproj;
bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model
// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector // NOLINT
std::vector<std::string> image; // path to image file(s)
// embedding
@ -404,28 +391,29 @@ struct common_params {
int32_t i_pos = -1; // position of the passkey in the junk text
// imatrix params
std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
int32_t i_chunk = 0; // start processing from this chunk
bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
// cvector-generator params
int n_pca_batch = 100;
int n_pca_iterations = 1000;
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
std::string cvector_outfile = "control_vector.gguf";
std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
bool spm_infill = false; // suffix/prefix/middle pattern for infill
std::string lora_outfile = "ggml-lora-merged-f16.gguf";
// batched-bench params
bool batched_bench_output_jsonl = false;
// common params
std::string out_file; // output filename for all example programs
};
// call once at the start of a program if it uses libcommon
@ -465,8 +453,6 @@ std::string string_repeat(const std::string & str, size_t n);
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
std::string regex_escape(const std::string & s);
template<class T>
static std::vector<T> string_split(const std::string & str, char delim) {
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
@ -544,11 +530,26 @@ struct llama_model_params common_model_params_to_llama ( common_params
struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
struct llama_model * common_load_model_from_url(
const std::string & model_url,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
std::pair<std::string, std::string> common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & hf_token);
// clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
std::string get_model_endpoint();
//
// Batch utils
//
@ -666,9 +667,3 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
}
//
// training utils
//
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);

View File

@ -16,9 +16,6 @@ using json = nlohmann::ordered_json;
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
auto has_max = max_items != std::numeric_limits<int>::max();
if (max_items == 0) {
return "";
}
if (min_items == 0 && max_items == 1) {
return item_rule + "?";
}
@ -267,7 +264,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
throw std::runtime_error("At least one of min_value or max_value must be set");
}
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
struct BuiltinRule {
std::string content;
@ -767,10 +764,11 @@ private:
public:
SchemaConverter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall)
bool dotall,
bool compact_spaces)
: _fetch_json(fetch_json), _dotall(dotall)
{
_rules["space"] = SPACE_RULE;
_rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
}
void resolve_refs(json & schema, const std::string & url) {
@ -1009,7 +1007,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
}
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);

View File

@ -16,6 +16,7 @@ struct common_grammar_builder {
struct common_grammar_options {
bool dotall = false;
bool compact_spaces = false;
};
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});

View File

@ -1,11 +1,9 @@
#include "sampling.h"
#include "common.h"
#include "log.h"
#include <cmath>
#include <unordered_map>
#include <algorithm>
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
@ -161,57 +159,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
} else {
std::vector<std::string> patterns_at_start;
std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) {
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
patterns_anywhere.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
{
const auto & pattern = trigger.value;
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
{
const auto token = trigger.token;
trigger_tokens.push_back(token);
break;
}
default:
GGML_ASSERT(false && "unknown trigger type");
}
}
std::vector<std::string> trigger_patterns;
if (!patterns_at_start.empty()) {
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
std::vector<const char *> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto & regex : trigger_patterns) {
trigger_patterns_c.push_back(regex.c_str());
std::vector<const char *> trigger_words;
trigger_words.reserve(params.grammar_trigger_words.size());
for (const auto & str : params.grammar_trigger_words) {
trigger_words.push_back(str.word.c_str());
}
grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())
? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
trigger_words.data(), trigger_words.size(),
params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
if (!grmr) {
return nullptr;
}
}
auto * result = new common_sampler {
@ -230,48 +188,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
params.logit_bias.data()));
if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
std::vector<const char *> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
if (params.top_n_sigma >= 0) {
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
} else {
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
std::vector<const char *> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
@ -473,7 +434,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x';
@ -489,7 +449,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
@ -504,7 +463,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
@ -518,7 +476,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
@ -535,16 +492,14 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
auto sampler = sampler_canonical_name_map.find(name);
if (sampler != sampler_canonical_name_map.end()) {
samplers.push_back(sampler->second);
continue;
}
if (allow_alt_names) {
sampler = sampler_alt_name_map.find(name);
if (sampler != sampler_alt_name_map.end()) {
samplers.push_back(sampler->second);
continue;
} else {
if (allow_alt_names) {
sampler = sampler_alt_name_map.find(name);
if (sampler != sampler_alt_name_map.end()) {
samplers.push_back(sampler->second);
}
}
}
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
}
return samplers;
@ -556,7 +511,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
@ -571,8 +525,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
const auto sampler = sampler_name_map.find(c);
if (sampler != sampler_name_map.end()) {
samplers.push_back(sampler->second);
} else {
LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
}
}

3032
llama/llama.cpp/examples/llava/clip.cpp vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,6 @@
#ifndef CLIP_H
#define CLIP_H
#include "ggml.h"
#include <stddef.h>
#include <stdint.h>
@ -30,28 +29,27 @@ struct clip_image_size {
int height;
};
struct clip_image_f32;
struct clip_image_u8_batch;
struct clip_image_f32_batch;
struct clip_context_params {
bool use_gpu;
enum ggml_log_level verbosity;
struct clip_image_u8_batch {
struct clip_image_u8 * data;
size_t size;
};
// deprecated, use clip_init
CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity);
struct clip_image_f32_batch {
struct clip_image_f32 * data;
size_t size;
};
CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params);
CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity);
CLIP_API void clip_free(struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
// TODO: should be enum, not string
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
@ -59,49 +57,24 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
GGML_DEPRECATED(CLIP_API int clip_n_patches(const struct clip_ctx * ctx),
"use clip_n_output_tokens instead");
GGML_DEPRECATED(CLIP_API int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img),
"use clip_n_output_tokens instead");
CLIP_API int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
// for M-RoPE, this will be the number of token positions in X and Y directions
// for other models, X will be the total number of tokens and Y will be 1
CLIP_API int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
// this should be equal to the embedding dimension of the text model
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
CLIP_API struct clip_image_size * clip_image_size_init(void);
CLIP_API struct clip_image_u8 * clip_image_u8_init (void);
CLIP_API struct clip_image_f32 * clip_image_f32_init(void);
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
CLIP_API struct clip_image_size * clip_image_size_init();
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
CLIP_API struct clip_image_f32 * clip_image_f32_init();
// nx, ny are the output image dimensions
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
// use for accessing underlay data of clip_image_f32_batch
CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
/**
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
* The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
*/
/** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */
CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
@ -122,8 +95,8 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API bool clip_is_llava(const struct clip_ctx * ctx);
CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx);
CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);

View File

@ -2,7 +2,6 @@
#include "llava.h"
#include "llama.h"
#include "ggml-cpp.h"
#include <algorithm>
#include <cerrno>
@ -11,7 +10,6 @@
#include <cstring>
#include <limits>
#include <vector>
#include <memory>
#if defined(LLAVA_LOG_OFF)
# define LOG_INF(...)
@ -47,17 +45,6 @@ struct clip_image_grid_shape {
int second;
};
// convenience cpp wrapper
struct clip_image_f32_batch_deleter {
void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
};
typedef std::unique_ptr<clip_image_f32_batch, clip_image_f32_batch_deleter> clip_image_f32_batch_ptr;
struct clip_image_size_deleter {
void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); }
};
typedef std::unique_ptr<clip_image_size, clip_image_size_deleter> clip_image_size_ptr;
/**
* Selects the best resolution from a list of possible resolutions based on the original size.
*
@ -113,13 +100,13 @@ static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<
}
// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out)
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out, clip_image_f32 * img_input) {
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) {
struct {
struct ggml_context * ctx;
} model;
const int32_t image_size = clip_get_image_size(ctx_clip);
const int32_t patch_size = clip_get_patch_size(ctx_clip);
const int32_t image_size = clip_image_size(ctx_clip);
const int32_t patch_size = clip_patch_size(ctx_clip);
int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches)
@ -176,7 +163,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
model.ctx = ggml_init(params);
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_output_tokens(ctx_clip, img_input), num_images - 1); // example: 4096 x 576 x 4
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4
// ggml_tensor_printf(image_features,"image_features",__LINE__,false,false);
// fill it with the image embeddings, ignoring the base
for (size_t i = 1; i < num_images; i++) {
@ -210,17 +197,13 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0);
// ggml_tensor_printf(flatten,"flatten",__LINE__,false,false);
ggml_build_forward_expand(gf, flatten);
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
GGML_ASSERT(backend != nullptr && "failed to initialize CPU backend");
ggml_backend_graph_compute(backend.get(), gf);
ggml_graph_compute_with_ctx(model.ctx, gf, 1);
struct ggml_tensor* result = ggml_graph_node(gf, -1);
memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context
// append without newline tokens (default behavior in llava_arch when not using unpad ):
memcpy(image_embd_out + clip_n_output_tokens(ctx_clip, img_input) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_output_tokens(ctx_clip, img_input));
memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_patches(ctx_clip));
// Debug: Test single segments
// Current findings: sending base image, sending a segment embedding all works similar to python
@ -263,9 +246,12 @@ static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size)
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
// std::vector<clip_image_f32*> img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336
clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init());
if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) {
clip_image_f32_batch img_res_v;
img_res_v.size = 0;
img_res_v.data = nullptr;
if (!clip_image_preprocess(ctx_clip, img, &img_res_v)) {
LOG_ERR("%s: unable to preprocess image\n", __func__);
delete[] img_res_v.data;
return false;
}
@ -273,72 +259,66 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get());
if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) {
std::vector<float *> image_embd_v;
image_embd_v.resize(n_imgs);
clip_image_size load_image_size;
image_embd_v.resize(img_res_v.size);
struct clip_image_size * load_image_size = clip_image_size_init();
for (size_t i = 0; i < n_imgs; i++) {
for (size_t i = 0; i < img_res_v.size; i++) {
const int64_t t_img_enc_step_start_us = ggml_time_us();
int nx = clip_image_f32_batch_nx(img_res_v.get(), i);
int ny = clip_image_f32_batch_ny(img_res_v.get(), i);
image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny));
int patch_size = 14;
load_image_size.width = nx;
load_image_size.height = ny;
clip_add_load_image_size(ctx_clip, &load_image_size);
image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
int patch_size=14;
load_image_size->width = img_res_v.data[i].nx;
load_image_size->height = img_res_v.data[i].ny;
clip_add_load_image_size(ctx_clip, load_image_size);
bool encoded = false;
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
if (clip_is_qwen2vl(ctx_clip)) {
encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]);
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
else {
encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]);
encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
}
if (!encoded) {
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs);
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
return false;
}
const int64_t t_img_enc_steop_batch_us = ggml_time_us();
LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0);
LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0);
}
const int64_t t_img_enc_batch_us = ggml_time_us();
LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
int n_img_pos_out = 0;
for (size_t i = 0; i < image_embd_v.size(); i++) {
int nx = clip_image_f32_batch_nx(img_res_v.get(), i);
int ny = clip_image_f32_batch_ny(img_res_v.get(), i);
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
std::memcpy(
image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
image_embd_v[i],
clip_embd_nbytes_by_img(ctx_clip, nx, ny));
n_img_pos_out += clip_n_output_tokens(ctx_clip, img_res);
clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]);
}
*n_img_pos = n_img_pos_out;
for (size_t i = 0; i < image_embd_v.size(); i++) {
free(image_embd_v[i]);
}
image_embd_v.clear();
load_image_size.width = img->nx;
load_image_size.height = img->ny;
clip_add_load_image_size(ctx_clip, &load_image_size);
LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height);
load_image_size->width = img->nx;
load_image_size->height = img->ny;
clip_add_load_image_size(ctx_clip, load_image_size);
LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
delete[] img_res_v.data;
img_res_v.size = 0;
img_res_v.data = nullptr;
}
else if (clip_is_glm(ctx_clip)){
struct clip_image_size * load_image_size = clip_image_size_init();
load_image_size->width = clip_image_f32_batch_nx(img_res_v.get(), 0);
load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0);
load_image_size->width = img_res_v.data[0].nx;
load_image_size->height = img_res_v.data[0].ny;
clip_add_load_image_size(ctx_clip, load_image_size);
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2);
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd);
int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2);
*n_img_pos = (pos * pos + 2);
if (!encoded){
LOG_ERR("Unable to encode image \n");
@ -347,9 +327,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
}
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
// flat / default llava-1.5 type embedding
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
*n_img_pos = clip_n_output_tokens(ctx_clip, img_res);
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096
*n_img_pos = clip_n_patches(ctx_clip);
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096
delete[] img_res_v.data;
if (!encoded) {
LOG_ERR("Unable to encode image\n");
@ -360,18 +340,17 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
// spatial_unpad llava-1.6 type embedding
// TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working
std::vector<float *> image_embd_v;
image_embd_v.resize(n_imgs);
for (size_t i = 0; i < n_imgs; i++) {
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i);
image_embd_v.resize(img_res_v.size);
for (size_t i = 0; i < img_res_v.size; i++) {
image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184
const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
if (!encoded) {
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs);
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
return false;
}
}
const int64_t t_img_enc_batch_us = ggml_time_us();
LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
const int32_t * image_grid = clip_image_grid(ctx_clip);
const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip);
@ -381,13 +360,17 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
}
const int32_t image_size = clip_get_image_size(ctx_clip);
// free all img_res_v - not needed anymore
delete[] img_res_v.data;
img_res_v.size = 0;
img_res_v.data = nullptr;
const int32_t image_size = clip_image_size(ctx_clip);
struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);
int n_img_pos_out;
clip_image_f32 * img_input = clip_image_f32_get_img(img_res_v.get(), 0);
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out, img_input);
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
*n_img_pos = n_img_pos_out;
for (size_t i = 0; i < image_embd_v.size(); i++) {
@ -462,7 +445,7 @@ struct llava_embd_batch {
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
llava_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
@ -474,6 +457,7 @@ struct llava_embd_batch {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*n_embd =*/ n_embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
@ -497,7 +481,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
n_eval = n_batch;
}
float * embd = image_embed->embed+i*n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
llava_embd_batch llava_batch = llava_embd_batch(embd, n_embd, n_eval, *n_past, 0);
if (llama_decode(ctx_llama, llava_batch.batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;

View File

@ -1,4 +1,4 @@
package mtmd
package llava
// #cgo CXXFLAGS: -std=c++11
// #cgo CPPFLAGS: -I${SRCDIR}/../../include -I${SRCDIR}/../../common

View File

@ -4,7 +4,6 @@
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-backend.h"
#include "ggml-opt.h"
#include <stddef.h>
#include <stdint.h>
@ -61,7 +60,6 @@ extern "C" {
struct llama_model;
struct llama_context;
struct llama_sampler;
struct llama_kv_cache;
typedef int32_t llama_pos;
typedef int32_t llama_token;
@ -108,12 +106,6 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
};
enum llama_rope_type {
@ -258,6 +250,7 @@ extern "C" {
llama_token * token;
float * embd;
int32_t n_embd;
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
@ -284,18 +277,10 @@ extern "C" {
};
};
struct llama_model_tensor_buft_override {
const char * pattern;
ggml_backend_buffer_type_t buft;
};
struct llama_model_params {
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
ggml_backend_dev_t * devices;
// NULL-terminated list of buffer types to use for tensors that match a pattern
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
@ -353,34 +338,35 @@ extern "C" {
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
// TODO: move at the end of the struct
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool cross_attn; // whether to use cross attention
// Abort callback
// if it returns true, execution of llama_decode() will be aborted
// currently works only with CPU execution
ggml_abort_callback abort_callback;
void * abort_callback_data;
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool op_offload; // whether to offload host tensor operations to device
};
// model quantization parameters
typedef struct llama_model_quantize_params {
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype
enum ggml_type output_tensor_type; // output tensor type
enum ggml_type token_embedding_type; // token embeddings tensor type
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
bool keep_split; // quantize to the same number of shards
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
void * tensor_types; // pointer to vector containing tensor types
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype
enum ggml_type output_tensor_type; // output tensor type
enum ggml_type token_embedding_type; // token embeddings tensor type
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
bool keep_split; // quantize to the same number of shards
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
} llama_model_quantize_params;
typedef struct llama_logit_bias {
@ -446,10 +432,6 @@ extern "C" {
size_t n_paths,
struct llama_model_params params);
LLAMA_API void llama_model_save_to_file(
const struct llama_model * model,
const char * path_model);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
"use llama_model_free instead");
@ -464,6 +446,10 @@ extern "C" {
struct llama_context_params params),
"use llama_init_from_model instead");
// TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches.
LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
@ -489,8 +475,7 @@ extern "C" {
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@ -607,7 +592,7 @@ extern "C" {
// KV cache
//
// TODO: start using struct llama_kv_cache
// TODO: remove llama_kv_cache_view_* API
// Information associated with an individual cell in the KV cache view.
struct llama_kv_cache_view_cell {
@ -662,19 +647,13 @@ extern "C" {
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
"use llama_kv_self_n_tokens instead");
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
"use llama_kv_self_used_cells instead");
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_self_clear(
LLAMA_API void llama_kv_cache_clear(
struct llama_context * ctx);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
@ -682,7 +661,7 @@ extern "C" {
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API bool llama_kv_self_seq_rm(
LLAMA_API bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
@ -692,7 +671,7 @@ extern "C" {
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_cp(
LLAMA_API void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
@ -700,17 +679,17 @@ extern "C" {
llama_pos p1);
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_self_seq_keep(
LLAMA_API void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
// - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_add(
LLAMA_API void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
@ -720,10 +699,10 @@ extern "C" {
// Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
// - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_div(
LLAMA_API void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
@ -731,76 +710,24 @@ extern "C" {
int d);
// Returns the largest position present in the KV cache for the specified sequence
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id);
llama_seq_id seq_id);
// TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
// how to avoid this?
// Defragment the KV cache
// This will be applied:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
// - explicitly with llama_kv_cache_update()
LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
DEPRECATED(LLAMA_API void llama_kv_cache_clear(
struct llama_context * ctx),
"use llama_kv_self_clear instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_rm instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1),
"use llama_kv_self_seq_cp instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_keep instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta),
"use llama_kv_self_seq_add instead");
DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d),
"use llama_kv_self_seq_div instead");
DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_kv_self_seq_pos_max instead");
DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
"use llama_kv_self_defrag instead");
DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
"use llama_kv_self_can_shift instead");
DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
"use llama_kv_self_update instead");
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
//
// State / sessions
@ -929,19 +856,14 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);
// Process a batch of tokens.
// In contrast to llama_decode() - this call does not use KV cache.
// For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success
// < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);
// Process a batch of tokens.
// Requires KV cache.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
@ -969,10 +891,6 @@ extern "C" {
// If set to true, the model will only attend to the past tokens
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
// Set whether the model is in warmup mode or not
// If true, all model tensors are activated during llama_decode() to load and cache their weights.
LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
@ -1242,7 +1160,6 @@ extern "C" {
"will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// Setting k <= 0 makes this a noop
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
@ -1289,38 +1206,22 @@ extern "C" {
float tau,
float eta);
/// @details Intializes a GBNF grammar, see grammars/README.md for details.
/// @param vocab The vocabulary that this grammar will be used with.
/// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
/// @param grammar_root The name of the start symbol for the grammar.
LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root);
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
/// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
/// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future.
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens),
"use llama_sampler_init_grammar_lazy_patterns instead");
/// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
/// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group.
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included.
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
const struct llama_vocab * vocab,
const char * grammar_str,
const char * grammar_root,
const char ** trigger_patterns,
size_t num_trigger_patterns,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
size_t num_trigger_tokens);
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
@ -1438,37 +1339,6 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
//
// training
//
// function that returns whether or not a given tensor contains trainable parameters
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
// always returns true
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
struct llama_opt_params {
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
};
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
LLAMA_API void llama_opt_epoch(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
#ifdef __cplusplus
}
#endif

View File

@ -4,13 +4,14 @@
#include "llama-mmap.h"
#include "llama-model.h"
#include <algorithm>
#include <map>
#include <cassert>
#include <stdexcept>
// vec
ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
return nullptr;
}
@ -18,7 +19,7 @@ ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
return tensors[il];
}
ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const {
struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
ggml_tensor * layer_dir = tensor_for(il);
if (layer_dir != nullptr) {
cur = ggml_add(ctx, cur, layer_dir);
@ -39,7 +40,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
struct ggml_init_params params = {
/*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
@ -90,7 +91,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
return true;
}
bool llama_adapter_cvec::apply(
int32_t llama_adapter_cvec::apply(
const llama_model & model,
const float * data,
size_t len,
@ -103,17 +104,17 @@ bool llama_adapter_cvec::apply(
// disable the current control vector (but leave allocated for later)
layer_start = -1;
layer_end = -1;
return true;
return 0;
}
if (n_embd != (int) hparams.n_embd) {
LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
return false;
return 1;
}
if (tensors.empty()) {
if (!init(model)) {
return false;
return 1;
}
}
@ -129,12 +130,12 @@ bool llama_adapter_cvec::apply(
}
}
return true;
return 0;
}
// lora
llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
const std::string name(w->name);
const auto pos = ab_map.find(name);
@ -145,11 +146,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
return nullptr;
}
static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
ggml_context * ctx_init;
gguf_init_params meta_gguf_params = {
struct gguf_init_params meta_gguf_params = {
/* .no_alloc = */ true,
/* .ctx = */ &ctx_init,
};
@ -200,7 +201,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
// add a new context
ggml_init_params params = {
struct ggml_init_params params = {
/*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
@ -247,29 +248,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}
}
// get extra buffer types of the CPU
// TODO: a more general solution for non-CPU extra buft should be imlpemented in the future
// ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
std::vector<ggml_backend_buffer_type_t> buft_extra;
{
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
if (ggml_backend_dev_get_extra_bufts_fn) {
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
while (extra_bufts && *extra_bufts) {
buft_extra.emplace_back(*extra_bufts);
++extra_bufts;
}
}
}
// add tensors
for (auto & it : ab_map) {
const std::string & name = it.first;
@ -286,26 +264,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
}
auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer);
// do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
for (auto & ex : buft_extra) {
if (ex == buft) {
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
buft = ggml_backend_dev_buffer_type(cpu_dev);
break;
}
}
LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
ggml_context * dev_ctx = ctx_for_buft(buft);
struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
// validate tensor shape
if (is_token_embd) {
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
@ -322,8 +281,8 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}
// save tensor to adapter
ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
ggml_set_name(tensor_a, w.a->name);
ggml_set_name(tensor_b, w.b->name);
adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
@ -349,7 +308,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
{
llama_file gguf_file(path_lora, "rb");
std::vector<uint8_t> read_buf;
auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
size_t size = ggml_nbytes(orig);
read_buf.resize(size);
@ -368,8 +327,8 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
}
llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
llama_adapter_lora * adapter = new llama_adapter_lora();
struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
struct llama_adapter_lora * adapter = new llama_adapter_lora();
try {
llama_adapter_lora_init_impl(*model, path_lora, *adapter);
@ -383,6 +342,6 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
return nullptr;
}
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
delete adapter;
}

View File

@ -15,11 +15,11 @@
//
struct llama_adapter_cvec {
ggml_tensor * tensor_for(int il) const;
struct ggml_tensor * tensor_for(int il) const;
ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const;
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
bool apply(
int32_t apply(
const llama_model & model,
const float * data,
size_t len,
@ -36,7 +36,7 @@ private:
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
std::vector<ggml_tensor *> tensors; // per layer
std::vector<struct ggml_tensor *> tensors; // per layer
};
//
@ -44,8 +44,8 @@ private:
//
struct llama_adapter_lora_weight {
ggml_tensor * a = nullptr;
ggml_tensor * b = nullptr;
struct ggml_tensor * a = nullptr;
struct ggml_tensor * b = nullptr;
// get actual scale based on rank and alpha
float get_scale(float alpha, float adapter_scale) const {
@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
}
llama_adapter_lora_weight() = default;
llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
};
struct llama_adapter_lora {
// map tensor name to lora_a_b
std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
@ -70,7 +70,5 @@ struct llama_adapter_lora {
llama_adapter_lora() = default;
~llama_adapter_lora() = default;
llama_adapter_lora_weight * get_weight(ggml_tensor * w);
llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
};
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;

View File

@ -6,7 +6,7 @@
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_MLLAMA, "mllama" },
{ LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GROK, "grok" },
@ -19,7 +19,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
@ -27,8 +26,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2, "qwen2" },
{ LLM_ARCH_QWEN2MOE, "qwen2moe" },
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
@ -55,7 +52,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK, "deepseek" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
@ -64,15 +60,12 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
{ LLM_ARCH_ARWKV7, "arwkv7" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_SOLAR, "solar" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -81,7 +74,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
{ LLM_KV_GENERAL_NAME, "general.name" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
{ LLM_KV_GENERAL_VERSION, "general.version" },
@ -108,7 +100,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@ -121,31 +112,25 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
{ LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
{ LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
{ LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@ -245,7 +230,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
},
},
{
LLM_ARCH_LLAMA4,
LLM_ARCH_MLLAMA,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
@ -268,9 +253,14 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
{ LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
{ LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
{ LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
{ LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
{ LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
{ LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
{ LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
},
},
{
@ -476,24 +466,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_NOMIC_BERT_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_JINA_BERT_V2,
{
@ -622,45 +594,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_QWEN3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_QWEN3MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_PHI2,
{
@ -878,12 +811,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
@ -1127,8 +1057,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@ -1145,22 +1073,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_PLM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_CHATGLM,
{
@ -1179,25 +1091,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
},
},
{
LLM_ARCH_GLM4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_BITNET,
{
@ -1382,74 +1275,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_RWKV7,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
{ LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
},
},
{
LLM_ARCH_ARWKV7,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_GRANITE,
{
@ -1548,27 +1373,20 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
},
},
{
LLM_ARCH_BAILINGMOE,
LLM_ARCH_MISTRAL3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
}
},
{
LLM_ARCH_UNKNOWN,
@ -1607,8 +1425,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@ -1635,12 +1468,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@ -1659,9 +1486,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
@ -1669,9 +1493,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@ -1701,6 +1522,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
// this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_K_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_O_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_Q_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_V_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_MLP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

View File

@ -10,7 +10,7 @@
enum llm_arch {
LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
LLM_ARCH_MLLAMA,
LLM_ARCH_DECI,
LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN,
@ -23,7 +23,6 @@ enum llm_arch {
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
@ -31,8 +30,6 @@ enum llm_arch {
LLM_ARCH_QWEN2,
LLM_ARCH_QWEN2MOE,
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
@ -59,7 +56,6 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK,
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
@ -68,15 +64,12 @@ enum llm_arch {
LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
LLM_ARCH_ARWKV7,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON,
LLM_ARCH_SOLAR,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_MISTRAL3,
LLM_ARCH_UNKNOWN,
};
@ -85,7 +78,6 @@ enum llm_kv {
LLM_KV_GENERAL_ARCHITECTURE,
LLM_KV_GENERAL_QUANTIZATION_VERSION,
LLM_KV_GENERAL_ALIGNMENT,
LLM_KV_GENERAL_FILE_TYPE,
LLM_KV_GENERAL_NAME,
LLM_KV_GENERAL_AUTHOR,
LLM_KV_GENERAL_VERSION,
@ -112,7 +104,6 @@ enum llm_kv {
LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
@ -125,7 +116,6 @@ enum llm_kv {
LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@ -140,16 +130,11 @@ enum llm_kv {
LLM_KV_ATTENTION_CAUSAL,
LLM_KV_ATTENTION_Q_LORA_RANK,
LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_DECAY_LORA_RANK,
LLM_KV_ATTENTION_ICLR_LORA_RANK,
LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
LLM_KV_ATTENTION_GATE_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@ -263,8 +248,6 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM,
LLM_TENSOR_POST_ATTN_NORM,
LLM_TENSOR_POST_MLP_NORM,
LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_X,
@ -272,20 +255,8 @@ enum llm_tensor {
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
LLM_TENSOR_TIME_MIX_W2,
LLM_TENSOR_TIME_MIX_A0,
LLM_TENSOR_TIME_MIX_A1,
LLM_TENSOR_TIME_MIX_A2,
LLM_TENSOR_TIME_MIX_V0,
LLM_TENSOR_TIME_MIX_V1,
LLM_TENSOR_TIME_MIX_V2,
LLM_TENSOR_TIME_MIX_G1,
LLM_TENSOR_TIME_MIX_G2,
LLM_TENSOR_TIME_MIX_K_K,
LLM_TENSOR_TIME_MIX_K_A,
LLM_TENSOR_TIME_MIX_R_K,
LLM_TENSOR_TIME_MIX_LERP_X,
LLM_TENSOR_TIME_MIX_LERP_W,
LLM_TENSOR_TIME_MIX_LERP_K,
@ -312,8 +283,6 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
@ -349,6 +318,14 @@ enum llm_tensor {
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_BSKCN_TV,
LLM_TENSOR_CROSS_ATTN_K_NORM,
LLM_TENSOR_CROSS_ATTN_K_PROJ,
LLM_TENSOR_CROSS_ATTN_O_PROJ,
LLM_TENSOR_CROSS_ATTN_Q_NORM,
LLM_TENSOR_CROSS_ATTN_Q_PROJ,
LLM_TENSOR_CROSS_ATTN_V_PROJ,
LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
LLM_TENSOR_CROSS_ATTN_MLP_GATE,
LLM_TENSOR_CONV1D,
LLM_TENSOR_CONVNEXT_DW,
LLM_TENSOR_CONVNEXT_NORM,

View File

@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
return ubatch;
}
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch;
this->n_embd = n_embd;
@ -203,7 +203,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
for (size_t i = 0; i < n_tokens; ++i) {
ids[i] = i;
}
if (simple_split) {
seq.resize(1);
llama_sbatch_seq & s = seq[0];
@ -213,7 +212,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
s.length = n_tokens;
return;
}
std::sort(ids.begin(), ids.end(),
[&batch](size_t a, size_t b) {
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
@ -241,7 +239,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
return n_seq_a > n_seq_b;
}
);
// init seq
llama_sbatch_seq * last_seq = nullptr;
@ -265,7 +262,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
seq.push_back(new_seq);
last_seq = &seq.back();
}
// keep shared prompts first at the end, then sort by length descending.
std::sort(seq.begin(), seq.end(),
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
@ -320,6 +316,7 @@ struct llama_batch llama_batch_get_one(
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
/*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
@ -332,6 +329,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*n_tokens =*/ 0,
/*tokens =*/ nullptr,
/*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
@ -340,6 +338,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
batch.n_embd = embd;
} else {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
}

View File

@ -42,9 +42,9 @@ struct llama_sbatch {
bool logits_all; // TODO: remove once lctx.logits_all is removed too
// sorted indices into the batch
std::vector<int64_t> ids;
std::vector<size_t> ids;
// batch indices of the output
std::vector<int64_t> out_ids;
std::vector<size_t> out_ids;
std::vector<llama_sbatch_seq> seq;
const llama_batch * batch = nullptr;
@ -70,8 +70,7 @@ struct llama_sbatch {
// sequence-wise split
llama_ubatch split_seq(size_t n_ubatch);
llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};
// temporary allocate memory for the input batch if needed

View File

@ -4,7 +4,6 @@
#include <map>
#include <sstream>
#include <algorithm>
#if __cplusplus >= 202000L
#define LU8(x) (const char*)(u8##x)
@ -35,7 +34,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
@ -51,8 +49,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGLM_4 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
@ -60,10 +58,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -83,9 +77,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
if (tmpl_contains("<|im_start|>")) {
return tmpl_contains("<|im_sep|>")
? LLM_CHAT_TEMPLATE_PHI_4
: tmpl_contains("<end_of_utterance>")
? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
: LLM_CHAT_TEMPLATE_CHATML;
: LLM_CHAT_TEMPLATE_CHATML;
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
if (tmpl_contains("[SYSTEM_PROMPT]")) {
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
@ -123,12 +115,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
}
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
return LLM_CHAT_TEMPLATE_PHI_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGLM_4;
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
return LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
return LLM_CHAT_TEMPLATE_ZEPHYR;
} else if (tmpl_contains("bos_token + message['role']")) {
@ -157,7 +145,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA_3;
} else if (tmpl_contains("[gMASK]sop")) {
// chatglm3-6b
return LLM_CHAT_TEMPLATE_CHATGLM_3;
return LLM_CHAT_TEMPLATE_CHATGML_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGML_4;
} else if (tmpl_contains(LU8("<用户>"))) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
return LLM_CHAT_TEMPLATE_MINICPM;
@ -177,12 +167,6 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GIGACHAT;
} else if (tmpl_contains("<|role_start|>")) {
return LLM_CHAT_TEMPLATE_MEGREZ;
} else if (tmpl_contains(" Ассистент:")) {
return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@ -203,20 +187,19 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
// Official mistral 'v7' template
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
for (auto message : chat) {
std::string role(message->role);
std::string content(message->content);
if (role == "system") {
ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
} else if (role == "user") {
ss << "[INST]" << trailing_space << content << "[/INST]";
} else {
ss << trailing_space << content << "</s>";
ss << "[INST] " << content << "[/INST]";
}
else {
ss << " " << content << "</s>";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
@ -439,7 +422,7 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) {
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
// chatglm3-6b
ss << "[gMASK]" << "sop";
for (auto message : chat) {
@ -449,14 +432,14 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>\n";
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) {
@ -583,66 +566,6 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|role_start|>assistant<|role_end|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
// Yandex template ("\n\n" is defined as EOT token)
ss << "<s>";
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "user") {
ss << " Пользователь: " << chat[i]->content << "\n\n";
} else if (role == "assistant") {
ss << " Ассистент: " << chat[i]->content << "\n\n";
}
}
// Add generation prompt if needed
if (add_ass) {
ss << " Ассистент:[SEP]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
// Bailing (Ling) template
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
role = "HUMAN";
} else {
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
}
ss << "<role>" << role << "</role>" << message->content;
}
if (add_ass) {
ss << "<role>ASSISTANT</role>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
// Llama 4
for (auto message : chat) {
std::string role(message->role);
ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
}
if (add_ass) {
ss << "<|header_start|>assistant<|header_end|>\n\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
// SmolVLM
ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "\n\n";
} else if (role == "user") {
ss << "User: " << message->content << "<end_of_utterance>\n";
} else {
ss << "Assistant: " << message->content << "<end_of_utterance>\n";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;
@ -661,3 +584,4 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
}
return (int32_t) LLM_CHAT_TEMPLATES.size();
}

View File

@ -14,7 +14,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MISTRAL_V3,
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
LLM_CHAT_TEMPLATE_MISTRAL_V7,
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
LLM_CHAT_TEMPLATE_PHI_3,
LLM_CHAT_TEMPLATE_PHI_4,
LLM_CHAT_TEMPLATE_FALCON_3,
@ -30,8 +29,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_DEEPSEEK_3,
LLM_CHAT_TEMPLATE_COMMAND_R,
LLM_CHAT_TEMPLATE_LLAMA_3,
LLM_CHAT_TEMPLATE_CHATGLM_3,
LLM_CHAT_TEMPLATE_CHATGLM_4,
LLM_CHAT_TEMPLATE_CHATGML_3,
LLM_CHAT_TEMPLATE_CHATGML_4,
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
@ -39,10 +38,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

File diff suppressed because it is too large Load Diff

View File

@ -3,222 +3,66 @@
#include "llama.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-adapter.h"
#include "llama-model.h"
#include "llama-kv-cache.h"
#include "llama-adapter.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
#include <map>
#include <unordered_map>
#include <vector>
struct llama_model;
struct llama_kv_cache;
class llama_io_read_i;
class llama_io_write_i;
#include <set>
struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs
llama_context(
const llama_model & model,
llama_context_params params);
llama_context(const llama_model & model)
: model(model)
, t_start_us(model.t_start_us)
, t_load_us(model.t_load_us) {}
~llama_context();
const struct llama_model & model;
void synchronize();
struct llama_cparams cparams;
struct llama_sbatch sbatch; // TODO: revisit if needed
struct llama_kv_cache kv_self;
struct llama_adapter_cvec cvec;
const llama_model & get_model() const;
const llama_cparams & get_cparams() const;
std::unordered_map<struct llama_adapter_lora *, float> lora;
ggml_backend_sched_t get_sched() const;
std::vector<ggml_backend_ptr> backends;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
ggml_context * get_ctx_compute() const;
ggml_backend_t backend_cpu = nullptr;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
uint32_t n_threads() const;
uint32_t n_threads_batch() const;
bool has_evaluated_once = false;
llama_kv_cache * get_kv_self();
const llama_kv_cache * get_kv_self() const;
mutable int64_t t_start_us;
mutable int64_t t_load_us;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
void kv_self_update();
mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
enum llama_pooling_type pooling_type() const;
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
float * get_logits();
float * get_logits_ith(int32_t i);
float * get_embeddings();
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch);
void detach_threadpool();
void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
void set_embeddings (bool value);
void set_causal_attn(bool value);
void set_warmup(bool value);
void set_adapter_lora(
llama_adapter_lora * adapter,
float scale);
bool rm_adapter_lora(
llama_adapter_lora * adapter);
void clear_adapter_lora();
bool apply_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);
int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch);
//
// state save/load
//
size_t state_get_size();
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);
bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count);
size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);
size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count);
//
// perf
//
llama_perf_context_data perf_get_data() const;
void perf_reset();
//
// training
//
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
void opt_epoch(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
void opt_epoch_iter(
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
const std::vector<llama_token> & tokens,
const std::vector<llama_token> & labels_sparse,
llama_batch & batch,
ggml_opt_epoch_callback callback,
bool train,
int64_t idata_in_loop,
int64_t ndata_in_loop,
int64_t t_loop_start);
private:
//
// output
//
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
int32_t output_reserve(int32_t n_outputs);
//
// graph
//
public:
int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
private:
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype);
llm_graph_cb graph_get_cb() const;
// TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
//
// members
//
const llama_model & model;
llama_cparams cparams;
llama_adapter_cvec cvec;
llama_adapter_loras loras;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_memory_i> memory;
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
bool logits_all = false;
// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
size_t embd_size = 0; // capacity (of floats) for embeddings
@ -228,50 +72,59 @@ private:
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
// whether we are computing encoder output or decoder output
bool is_encoding = false;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
// TODO: find a better way to accommodate mutli-dimension position encoding methods
// number of position id each token get, 1 for each token in most cases.
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
int n_pos_per_token = 1;
// output of the encoder part of the encoder-decoder models
std::vector<float> embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_ptr sched;
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;
ggml_context_ptr ctx_compute;
// training
ggml_opt_context_t opt_ctx = nullptr;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
ggml_abort_callback abort_callback = nullptr;
void * abort_callback_data = nullptr;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
// buffer types used for the compute buffer of each backend
std::vector<ggml_backend_t> backend_ptrs;
std::vector<ggml_backend_buffer_type_t> backend_buft;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
bool has_evaluated_once = false;
// perf
mutable int64_t t_start_us = 0;
mutable int64_t t_load_us = 0;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
struct ggml_tensor * inp_cross_attn_state; // F32 [4, n_embd, 1061]
};
// TODO: make these methods of llama_context
void llama_set_k_shift(struct llama_context & lctx);
void llama_set_s_copy(struct llama_context & lctx);
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
void llama_output_reorder(struct llama_context & ctx);
// For internal test use
// TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);

View File

@ -29,8 +29,7 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
bool no_perf;
bool warmup;
bool op_offload;
bool cross_attn;
enum llama_pooling_type pooling_type;

View File

@ -907,7 +907,6 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
@ -963,7 +962,6 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@ -971,25 +969,28 @@ struct llama_grammar * llama_grammar_init_impl(
/* .awaiting_trigger = */ false,
/* .trigger_buffer = */ "",
/* .trigger_tokens = */ {},
/* .trigger_patterns = */ {},
/* .trigger_words = */ {},
};
}
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_patterns,
size_t num_trigger_patterns,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
llama_grammar_parser parser;
// if there is a grammar, parse it
// rules will be empty (default) if there are parse errors
if (!parser.parse(grammar_str) || parser.rules.empty()) {
if (!parser.parse(grammar_str)) {
return nullptr;
}
// will be empty (default) if there are parse errors
if (parser.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
return nullptr;
}
@ -1053,16 +1054,14 @@ struct llama_grammar * llama_grammar_init_impl(
} while (true);
std::vector<llama_token> vec_trigger_tokens;
std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
std::vector<std::string> vec_trigger_words;
for (size_t i = 0; i < num_trigger_tokens; i++) {
GGML_ASSERT(trigger_tokens != nullptr);
vec_trigger_tokens.push_back(trigger_tokens[i]);
}
for (size_t i = 0; i < num_trigger_patterns; i++) {
GGML_ASSERT(trigger_patterns != nullptr);
auto & trigger = vec_trigger_patterns.emplace_back();
trigger.pattern = trigger_patterns[i];
trigger.regex = std::regex(trigger.pattern);
for (size_t i = 0; i < num_trigger_words; i++) {
GGML_ASSERT(trigger_words != nullptr);
vec_trigger_words.push_back(trigger_words[i]);
}
// Important: vec_rules has to be moved here, not copied, because stacks contains
@ -1070,7 +1069,6 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@ -1078,7 +1076,7 @@ struct llama_grammar * llama_grammar_init_impl(
/* .awaiting_trigger = */ lazy,
/* .trigger_buffer = */ "",
std::move(vec_trigger_tokens),
std::move(vec_trigger_patterns),
std::move(vec_trigger_words),
};
}
@ -1091,9 +1089,8 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
}
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
auto * result = new llama_grammar {
llama_grammar * result = new llama_grammar {
grammar.vocab,
grammar.o_vocab,
grammar.rules,
grammar.stacks,
grammar.partial_utf8,
@ -1101,7 +1098,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
grammar.awaiting_trigger,
grammar.trigger_buffer,
grammar.trigger_tokens,
grammar.trigger_patterns,
grammar.trigger_words,
};
// redirect elements in stacks to point to new rules
@ -1121,6 +1118,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
}
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
GGML_ASSERT(grammar.vocab != nullptr);
if (grammar.awaiting_trigger) {
return;
@ -1142,13 +1140,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
for (size_t i = 0; i < cur_p->size; ++i) {
const llama_token id = cur_p->data[i].id;
const std::string piece = grammar.o_vocab ?
grammar.o_vocab->token_to_piece(id) :
grammar.vocab->token_to_piece(id);
const std::string & piece = grammar.vocab->token_to_piece(id);
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(id) : grammar.vocab->is_eog(id);
if (is_eog) {
if (grammar.vocab->is_eog(id)) {
if (!allow_eog) {
cur_p->data[i].logit = -INFINITY;
}
@ -1167,10 +1161,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
}
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
GGML_ASSERT(grammar.vocab != nullptr);
const std::string piece = grammar.o_vocab ?
grammar.o_vocab->token_to_piece(token) :
grammar.vocab->token_to_piece(token);
const auto & piece = grammar.vocab->token_to_piece(token);
if (grammar.awaiting_trigger) {
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
@ -1180,18 +1173,16 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
return;
} else {
// TODO: consider a smarter incremental substring search algorithm (store last position to search from).
grammar.trigger_buffer += piece;
std::smatch match;
for (const auto & trigger_pattern : grammar.trigger_patterns) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
for (const auto & word : grammar.trigger_words) {
auto pos = grammar.trigger_buffer.find(word);
if (pos != std::string::npos) {
grammar.awaiting_trigger = false;
// get from the first match to the end of the string
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
auto constrained_str = grammar.trigger_buffer.substr(pos);
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str);
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
return;
}
}
@ -1200,14 +1191,13 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
}
}
const bool is_eog = grammar.o_vocab ? grammar.o_vocab->is_eog(token) : grammar.vocab->is_eog(token);
if (is_eog) {
if (grammar.vocab->is_eog(token)) {
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
return;
}
}
GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
GGML_ABORT("fatal error");
}
llama_grammar_accept_str(grammar, piece);
@ -1227,28 +1217,3 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
}
}
const std::string & ollama_vocab::token_to_piece(const uint32_t token) const {
try {
return token_to_piece_map.at(token);
} catch (const std::out_of_range&) {
throw std::runtime_error("Token not found in vocabulary: " + std::to_string(token));
}
}
void ollama_vocab::add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces) {
for (size_t i = 0; i < n_tokens; i++) {
token_to_piece_map[tokens[i]] = pieces[i];
}
}
bool ollama_vocab::is_eog(const uint32_t token) const {
return special_eog_ids.count(token) > 0;
}
void ollama_vocab::set_eog_tokens(const uint32_t* tokens, size_t n_tokens) {
for (size_t i = 0; i < n_tokens; i++) {
special_eog_ids.insert(tokens[i]);
}
}

View File

@ -3,22 +3,10 @@
#include "llama.h"
#include <map>
#include <regex>
#include <string>
#include <vector>
#include <set>
struct llama_vocab;
struct ollama_vocab {
std::map<uint32_t, std::string> token_to_piece_map;
std::set<uint32_t> special_eog_ids;
const std::string & token_to_piece(const uint32_t token) const;
void add_token_pieces(const uint32_t* tokens, size_t n_tokens, const char** pieces);
void set_eog_tokens(const uint32_t* tokens, size_t n_tokens);
bool is_eog(const uint32_t token) const;
};
// grammar element type
enum llama_gretype {
@ -117,15 +105,9 @@ struct llama_grammar_parser {
void print(FILE * file);
};
struct llama_grammar_trigger_pattern {
std::string pattern;
std::regex regex;
};
struct llama_grammar {
// note: allow null vocab for testing (not great)
const llama_vocab * vocab;
const ollama_vocab * o_vocab;
const llama_grammar_rules rules; // TODO: shared ptr
llama_grammar_stacks stacks;
@ -140,10 +122,7 @@ struct llama_grammar {
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
std::vector<llama_grammar_trigger_pattern>
trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
// string, and the grammar will be given the string from the first match group onwards.
std::vector<std::string> trigger_words;
};
//
@ -153,19 +132,17 @@ struct llama_grammar {
// note: needed for tests (not great)
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
const struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_patterns,
size_t num_trigger_patterns,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);

Some files were not shown because too many files have changed in this diff Show More