Compare commits


33 Commits

Author SHA1 Message Date
Devon Rifkin
ad3c7c9bda strip out thinking tags in message history for qwen3 & r1 (#10490)
* strip out thinking tags in message history for qwen3 & r1

This is in advance of "proper" support, where we'll make reasoning
configurable, parse out thinking/reasoning tags, and provide them to
the caller. These models expect no thinking tags in the message
history, so this should improve quality.

* parse model names instead of hacky prefix check
2025-04-30 13:57:45 -07:00
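A minimal sketch of the history-stripping idea described above (the `<think>` tag name and the regex approach are illustrative assumptions, not the PR's actual parser):

```go
package parser

import (
	"regexp"
	"strings"
)

// thinkRE matches a complete <think>...</think> block, including newlines.
var thinkRE = regexp.MustCompile(`(?s)<think>.*?</think>`)

// stripThinking removes reasoning blocks from prior turns before they are
// fed back to the model as chat history.
func stripThinking(history []string) []string {
	out := make([]string, len(history))
	for i, msg := range history {
		out[i] = strings.TrimSpace(thinkRE.ReplaceAllString(msg, ""))
	}
	return out
}
```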
Daniel Hiltgen
415c8fcc3d Fix "Stopping..." scheduler hang (#10487)
* Adjust initial scheduler refCount

Ensure we only set the refCount on success

* sched: fix lock order inversion deadlock

Under certain race conditions, there was a scenario where the scheduler would
get into a deadlock while trying to update free space information while a model
was trying to unload.
2025-04-30 11:26:52 -07:00
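A minimal sketch of the deadlock class being fixed (types and names are assumptions, not the scheduler's actual code): a lock order inversion occurs when one goroutine takes the scheduler lock and then a runner lock while another goroutine takes them in the opposite order. The fix is to pick one global order and never wait on the outer lock while holding an inner one:

```go
package scheduler

import "sync"

type runner struct {
	mu       sync.Mutex // rule: only ever acquired after sched.mu is released
	refCount int
}

type sched struct {
	mu     sync.Mutex // rule: always the first lock taken
	loaded map[string]*runner
}

// updateFreeSpace follows the global order: sched.mu first, runner.mu
// second. As long as unload follows the same order, neither goroutine can
// block on a lock the other holds.
func (s *sched) updateFreeSpace(name string) {
	s.mu.Lock()
	r := s.loaded[name]
	s.mu.Unlock() // release before taking the runner lock

	if r != nil {
		r.mu.Lock()
		defer r.mu.Unlock()
		// ... recompute free VRAM for this runner ...
	}
}
```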
Daniel Hiltgen
718eda1b3e Narrow set of paths we load GGML from (#10485)
Users may have other incompatible GGML installs on their systems.
This will prevent us from trying to load them from the path.
2025-04-30 11:25:22 -07:00
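A sketch of the idea (assumed names and layout, not the loader's actual code): resolve GGML libraries only from directories shipped alongside the Ollama binary, rather than letting the search fall through to system-wide locations where an unrelated, incompatible GGML build could be picked up:

```go
package loader

import (
	"os"
	"path/filepath"
)

// ggmlSearchPaths returns the only directories GGML libraries may be
// loaded from: locations relative to our own executable.
func ggmlSearchPaths() ([]string, error) {
	exe, err := os.Executable()
	if err != nil {
		return nil, err
	}
	dir := filepath.Dir(exe)
	// Deliberately exclude PATH / LD_LIBRARY_PATH entries, where a user's
	// incompatible GGML install could shadow the bundled one.
	return []string{dir, filepath.Join(dir, "lib", "ollama")}, nil
}
```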
Shahin R
421b7edeb4 readme: add link to lumina, a lightweight React frontend client (#10378) 2025-04-30 09:50:47 -07:00
batuhankadioglu
7b68e254c2 all: update several golang.org/x packages (#10436) 2025-04-29 16:51:09 -07:00
Daniel Hiltgen
7bec2724a5 integration: fix embedding tests error handling (#10478)
The cleanup routine from InitServerconnection should run in the test case's defer so that failures are properly detected and the server logs are reported.
2025-04-29 11:57:54 -07:00
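The pattern the fix describes, sketched (the helper name and its signature are assumptions based on the commit message, not the actual test suite): the cleanup that dumps server logs must be deferred inside the test case itself, so it still runs when an assertion fails.

```go
package integration

import (
	"context"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// InitServerConnection is a stand-in for the suite's real helper: it
// connects to a test server and returns a cleanup that reports server logs.
func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, func()) {
	t.Helper()
	client, err := api.ClientFromEnvironment()
	if err != nil {
		t.Fatal(err)
	}
	return client, func() { /* collect and report server logs here */ }
}

func TestEmbeddings(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	client, cleanup := InitServerConnection(ctx, t)
	defer cleanup() // runs on failure too, so the logs are still captured

	resp, err := client.Embeddings(ctx, &api.EmbeddingRequest{
		Model:  "all-minilm",
		Prompt: "why is the sky blue?",
	})
	if err != nil {
		t.Fatal(err)
	}
	if len(resp.Embedding) == 0 {
		t.Fatal("expected a non-empty embedding")
	}
}
```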
Jesse Gross
a27462b708 ollamarunner: Temporarily disable worst case graph preallocation
When we later have a large batch running purely on a CPU, this
results in the error:
GGML_ASSERT(talloc->buffer_id >= 0)

Disabling this means that we will incrementally reallocate memory
as the graph grows.

Fixes #10410
2025-04-29 11:04:58 -07:00
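The trade-off in plain terms (a conceptual sketch, not GGML's API): worst-case preallocation reserves a buffer for the largest possible graph once up front; with it disabled, the buffer is grown only when the current graph outgrows what is already reserved.

```go
package alloc

// ensure grows buf only when the current graph needs more than is already
// reserved, the incremental strategy this commit falls back to.
func ensure(buf []float32, need int) []float32 {
	if need <= cap(buf) {
		return buf[:need] // reuse the existing reservation
	}
	grown := make([]float32, need) // reallocate as the graph grows
	copy(grown, buf)
	return grown
}
```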
crStiv
6bf0b8193a readme: fix typos (#10399) 2025-04-29 10:30:44 -07:00
Devon Rifkin
db428adbb8 Merge pull request #10468 from ollama/drifkin/num-parallel-1 2025-04-29 10:21:36 -07:00
Devon Rifkin
fe5b9bb21b lower default num parallel to 2
this is in part to "pay" for #10452, which doubled the default context length. The combination isn't fully neutral, though: the old 4x2k limit and the new 2x4k limit are memory equivalent (both budget 8k tokens of KV cache in total), but the single-model 1x fallback is larger at 4k than it was at 2k.
2025-04-29 02:04:14 -07:00
Devon Rifkin
6ec71d8fb6 Merge pull request #10452 from ollama/drifkin/4096-context-length
config: update default context length to 4096
2025-04-28 17:13:51 -07:00
Devon Rifkin
44b466eeb2 config: update default context length to 4096 2025-04-28 17:03:27 -07:00
Devon Rifkin
a25f3f8260 Merge pull request #10451 from ollama/revert-10364-drifkin/context-length
Revert "increase default context length to 4096"
2025-04-28 17:02:10 -07:00
Devon Rifkin
dd93e1af85 Revert "increase default context length to 4096 (#10364)"
This reverts commit 424f648632.
2025-04-28 16:54:11 -07:00
Michael Yang
5cfc1c39f3 model: fix build (#10416) 2025-04-25 19:24:48 -07:00
Michael Yang
f0ad49ea17 memory 2025-04-25 16:59:20 -07:00
Michael Yang
7ba9fa9c7d fixes for maverick 2025-04-25 16:59:20 -07:00
Michael Yang
8bf11b84c1 chunked attention 2025-04-25 16:59:20 -07:00
Michael Yang
470af8ab89 connect vision to text 2025-04-25 16:59:20 -07:00
Michael Yang
178761aef3 image processing
Co-authored-by: Patrick Devine <patrick@infrahq.com>
2025-04-25 16:59:20 -07:00
Michael Yang
f0c66e6dea llama4 2025-04-25 16:59:20 -07:00
Michael Yang
54055a6dae fix test 2025-04-25 16:59:01 -07:00
Michael Yang
340448d2d1 explicitly decode maxarraysize 1024 2025-04-25 16:59:01 -07:00
Michael Yang
ced7d0e53d fix parameter count 2025-04-25 16:59:01 -07:00
Michael Yang
a0dba0f8ae default slice values 2025-04-25 16:59:01 -07:00
Michael Yang
5e20b170a7 update comment 2025-04-25 16:59:01 -07:00
Michael Yang
d26c18e25c fix token type 2025-04-25 16:59:01 -07:00
Michael Yang
8d376acc9b zero means zero
using a default of 1024 when asking for zero is confusing, since most
callers seem to assume 0 means do not read any data
2025-04-25 16:59:01 -07:00
Michael Yang
dc1e81f027 convert: use -1 for read all 2025-04-25 16:59:01 -07:00
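The sentinel semantics these two commits settle on, in one predicate (mirroring the `canCollectArray` check visible in the gguf.go diff below): a negative maxArraySize collects every array, and 0 genuinely means zero rather than silently becoming 1024.

```go
package gguf

// canCollect reports whether an array of size elements should be
// materialized when decoding: a negative maxArraySize means collect
// everything, and any non-negative value is an exact cap, so asking
// for 0 really collects nothing.
func canCollect(size, maxArraySize int) bool {
	return maxArraySize < 0 || size <= maxArraySize
}
```

Callers that want every array now pass -1, as the convert tests below do with `ggml.Decode(r, -1)`.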
Michael Yang
5d0279164c generic ggml.array 2025-04-25 16:59:01 -07:00
Michael Yang
214a7678ea fix superfluous call to WriteHeader
the first call to http.ResponseWriter.Write implicitly calls WriteHeader
with http.StatusOK if it hasn't already been called. once WriteHeader
has been called, subsequent calls have no effect. Write is called when
JSON encoding progressUpdateJSON{}, so any call to
http.ResponseWriter.WriteHeader after the first encode is useless and
produces a warning:

http: superfluous response.WriteHeader call from github.com/ollama/ollama/server/internal/registry.(*statusCodeRecorder).WriteHeader (server.go:77)
2025-04-25 16:58:49 -07:00
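A minimal reproduction of the warning (standard net/http behavior, independent of the registry code):

```go
package main

import (
	"encoding/json"
	"log"
	"net/http"
)

func handler(w http.ResponseWriter, r *http.Request) {
	// Encode writes to w, which implicitly calls WriteHeader(http.StatusOK).
	json.NewEncoder(w).Encode(map[string]string{"status": "pulling"})

	// By now the header is already sent; this call is a no-op and logs:
	//   http: superfluous response.WriteHeader call from ...
	w.WriteHeader(http.StatusOK)
}

func main() {
	http.HandleFunc("/", handler)
	log.Fatal(http.ListenAndServe("localhost:8080", nil))
}
```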
Michael Yang
4892872c18 convert: change to colmajor 2025-04-25 15:27:39 -07:00
Michael Yang
0b9198bf47 ci: silence deprecated gpu targets warning 2025-04-25 13:37:54 -07:00
58 changed files with 2026 additions and 406 deletions

View File

@@ -21,14 +21,16 @@
"name": "CUDA 11", "name": "CUDA 11",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86" "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
} }
}, },
{ {
"name": "CUDA 12", "name": "CUDA 12",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120" "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
} }
}, },
{ {

View File

@@ -285,7 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
 - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
 - [Saddle](https://github.com/jikkuatwork/saddle)
-- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
+- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
 - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
 - [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
 - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
@@ -325,14 +325,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
 - [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
 - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
-- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
+- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support, and multiple large language models.)
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
 - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
 - [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
-- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord )
+- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in Discord)
 - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
 - [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
-- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
+- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy-to-use GUI with sample custom LLM for Drivers Education)
 - [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
 - [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
 - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
@@ -341,16 +341,16 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
 - [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
 - [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
-- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac)
-- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita)
+- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac)
+- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita)
 - [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
 - [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
 - [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
-- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
+- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
 - [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
 - [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
 - [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
-- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
+- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
 - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
 - [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
 - [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
@@ -368,7 +368,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
 - [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
 - [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
-- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
+- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard, and said in the meetings)
 - [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
 - [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
 - [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
@@ -386,7 +386,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
 - [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
 - [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
-- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
+- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
 - [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
 - [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
 - [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
@@ -399,6 +399,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
 - [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
 - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
+- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
 
 ### Cloud
@@ -440,7 +441,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
 - [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
 - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
-- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
+- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
 - [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
 
 ### Apple Vision Pro
@@ -515,7 +516,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
 - [GoLamify](https://github.com/prasad89/golamify)
 - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
-- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
+- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
 - [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
 - [Ollama for Zig](https://github.com/dravenk/ollama-zig)
 - [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
@@ -524,11 +525,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Mobile
 
-- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
+- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS, and iPad)
 - [Enchanted](https://github.com/AugustDev/enchanted)
 - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
-- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
+- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
 - [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
 - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
@@ -552,7 +553,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
 - [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
 - [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
-- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
+- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot)
 - [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
 - [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
 - [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
@@ -562,8 +563,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
 - [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
 - [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
-- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
-- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
+- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depend on ollama server)
+- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.)
 - [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
 - [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
 - [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)

View File

@@ -22,7 +22,6 @@ import (
 	"sort"
 	"strconv"
 	"strings"
-	"sync"
 	"sync/atomic"
 	"syscall"
 	"time"
@@ -32,7 +31,6 @@ import (
 	"github.com/olekukonko/tablewriter"
 	"github.com/spf13/cobra"
 	"golang.org/x/crypto/ssh"
-	"golang.org/x/sync/errgroup"
 	"golang.org/x/term"
 
 	"github.com/ollama/ollama/api"
@@ -108,7 +106,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	}
 	spinner.Stop()
 
-	req.Model = args[0]
+	req.Name = args[0]
 	quantize, _ := cmd.Flags().GetString("quantize")
 	if quantize != "" {
 		req.Quantize = quantize
@@ -119,43 +117,26 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	var mu sync.Mutex
-	var g errgroup.Group
-	g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
-
-	// copy files since we'll be modifying the map
-	temp := req.Files
-	req.Files = make(map[string]string, len(temp))
-	for f, digest := range temp {
-		g.Go(func() error {
+	if len(req.Files) > 0 {
+		fileMap := map[string]string{}
+		for f, digest := range req.Files {
 			if _, err := createBlob(cmd, client, f, digest, p); err != nil {
 				return err
 			}
-
-			mu.Lock()
-			req.Files[filepath.Base(f)] = digest
-			mu.Unlock()
-			return nil
-		})
+			fileMap[filepath.Base(f)] = digest
+		}
+		req.Files = fileMap
 	}
 
-	// copy files since we'll be modifying the map
-	temp = req.Adapters
-	req.Adapters = make(map[string]string, len(temp))
-	for f, digest := range temp {
-		g.Go(func() error {
+	if len(req.Adapters) > 0 {
+		fileMap := map[string]string{}
+		for f, digest := range req.Adapters {
 			if _, err := createBlob(cmd, client, f, digest, p); err != nil {
 				return err
 			}
-
-			mu.Lock()
-			req.Adapters[filepath.Base(f)] = digest
-			mu.Unlock()
-			return nil
-		})
-	}
-
-	if err := g.Wait(); err != nil {
-		return err
+			fileMap[filepath.Base(f)] = digest
+		}
+		req.Adapters = fileMap
 	}
 
 	bars := make(map[string]*progress.Bar)
@@ -232,7 +213,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
 	}
 	}()
 
-	if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
+	if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
 		return "", err
 	}
 	return digest, nil
@@ -1426,7 +1407,6 @@ func NewCLI() *cobra.Command {
 				envVars["OLLAMA_LLM_LIBRARY"],
 				envVars["OLLAMA_GPU_OVERHEAD"],
 				envVars["OLLAMA_LOAD_TIMEOUT"],
-				envVars["OLLAMA_CONTEXT_LENGTH"],
 			})
 		default:
 			appendEnvDocs(cmd, envs)

View File

@@ -690,7 +690,7 @@ func TestCreateHandler(t *testing.T) {
 				return
 			}
 
-			if req.Model != "test-model" {
+			if req.Name != "test-model" {
 				t.Errorf("expected model name 'test-model', got %s", req.Name)
 			}

View File

@@ -4,9 +4,10 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"io/fs"
 	"log/slog"
-	"os"
+	"slices"
 	"strings"
 
 	"github.com/ollama/ollama/fs/ggml"
@@ -84,14 +85,6 @@ func (ModelParameters) specialTokenTypes() []string {
 	}
 }
 
-func (ModelParameters) writeFile(f *os.File, kv ggml.KV, ts []ggml.Tensor) error {
-	return ggml.WriteGGUF(f, kv, ts)
-}
-
-func (AdapterParameters) writeFile(f *os.File, kv ggml.KV, ts []ggml.Tensor) error {
-	return ggml.WriteGGUF(f, kv, ts)
-}
-
 type ModelConverter interface {
 	// KV maps parameters to LLM key-values
 	KV(*Tokenizer) ggml.KV
@@ -103,8 +96,6 @@ type ModelConverter interface {
 
 	// specialTokenTypes returns any special token types the model uses
 	specialTokenTypes() []string
-	// writeFile writes the model to the provided io.WriteSeeker
-	writeFile(*os.File, ggml.KV, []ggml.Tensor) error
 }
 
 type moreParser interface {
@@ -119,11 +110,9 @@ type AdapterConverter interface {
 	// Replacements returns a list of string pairs to replace in tensor names.
 	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
 	Replacements() []string
-
-	writeFile(*os.File, ggml.KV, []ggml.Tensor) error
 }
 
-func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
+func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
 	bts, err := fs.ReadFile(fsys, "adapter_config.json")
 	if err != nil {
 		return err
@@ -158,14 +147,14 @@
 		return err
 	}
 
-	return conv.writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
+	return writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
 }
 
 // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
 // and files it finds in the input path.
 // Supported input model formats include safetensors.
 // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
-func ConvertModel(fsys fs.FS, f *os.File) error {
+func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	bts, err := fs.ReadFile(fsys, "config.json")
 	if err != nil {
 		return err
@@ -184,6 +173,8 @@
 	switch p.Architectures[0] {
 	case "LlamaForCausalLM":
 		conv = &llamaModel{}
+	case "Llama4ForConditionalGeneration":
+		conv = &llama4Model{}
 	case "Mistral3ForConditionalGeneration":
 		conv = &mistral3Model{}
 	case "MixtralForCausalLM":
@@ -248,5 +239,13 @@
 		return err
 	}
 
-	return conv.writeFile(f, conv.KV(t), conv.Tensors(ts))
+	return writeFile(ws, conv.KV(t), conv.Tensors(ts))
+}
+
+func writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
+	for i := range ts {
+		ts[i].Shape = slices.Clone(ts[i].Shape)
+		slices.Reverse(ts[i].Shape)
+	}
+	return ggml.WriteGGUF(ws, kv, ts)
 }

View File

@@ -42,6 +42,8 @@ type llamaModel struct {
 	LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
 	NormEpsilon      float32 `json:"norm_epsilon"`
 	HeadDim          uint32  `json:"head_dim"`
+
+	skipRepack bool
 }
 
 var _ ModelConverter = (*llamaModel)(nil)
@@ -70,6 +72,10 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
 		kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
 	}
 
+	if p.HeadDim > 0 {
+		kv["llama.attention.head_dim"] = p.HeadDim
+	}
+
 	if p.RopeTheta > 0 {
 		kv["llama.rope.freq_base"] = p.RopeTheta
 	}
@@ -133,9 +139,10 @@
 	}
 
 	for _, t := range ts {
-		if strings.HasSuffix(t.Name(), "attn_q.weight") ||
-			strings.HasSuffix(t.Name(), "attn_k.weight") {
-			t.SetRepacker(p.repack)
+		if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
+			if !p.skipRepack {
+				t.SetRepacker(p.repack)
+			}
 		}
 
 		out = append(out, ggml.Tensor{

convert/convert_llama4.go (new file, +169 lines)
View File

@@ -0,0 +1,169 @@
package convert
import (
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type llama4Model struct {
ModelParameters
TextModel struct {
llamaModel
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NumLocalExperts uint32 `json:"num_local_experts"`
InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
UseQKNorm bool `json:"use_qk_norm"`
IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"`
AttentionChunkSize uint32 `json:"attention_chunk_size"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
NormEpsilon float32 `json:"norm_eps"`
PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
} `json:"vision_config"`
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"
for k, v := range p.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
}
}
kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize
kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
return kv
}
// Replacements implements ModelConverter.
func (p *llama4Model) Replacements() []string {
return append(
p.TextModel.Replacements(),
"language_model.", "",
"vision_model", "v",
"multi_modal_projector", "mm",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.", "ffn_",
"shared_expert.down_proj", "down_shexp",
"shared_expert.gate_proj", "gate_shexp",
"shared_expert.up_proj", "up_shexp",
"experts.down_proj", "down_exps.weight",
"experts.gate_up_proj", "gate_up_exps.weight",
"router", "gate_inp",
"patch_embedding.linear", "patch_embedding",
)
}
// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
var textTensors []Tensor
for _, t := range ts {
if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
// gate and up projectors are fused
// dims[1], dims[2] must be swapped
// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
halfDim := int(t.Shape()[2]) / 2
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2]/2, newShape[1]
for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
// clone tensor since we need separate repackers
tt := t.Clone()
tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
out = append(out, ggml.Tensor{
Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
Kind: tt.Kind(),
Shape: newShape,
WriterTo: tt,
})
}
} else if strings.Contains(t.Name(), "ffn_down_exps") {
// dims[1], dims[2] must be swapped
// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
t.SetRepacker(p.repack())
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2], newShape[1]
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
} else {
textTensors = append(textTensors, t)
}
}
p.TextModel.skipRepack = true
out = append(out, p.TextModel.Tensors(textTensors)...)
return out
}
func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
return func(name string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
if err := t.T(0, 2, 1); err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be returned as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@@ -11,7 +11,6 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
-	"math"
 	"os"
 	"path/filepath"
 	"slices"
@@ -48,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
 	}
 	t.Cleanup(func() { r.Close() })
 
-	m, _, err := ggml.Decode(r, math.MaxInt)
+	m, _, err := ggml.Decode(r, -1)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -332,7 +331,7 @@ func TestConvertAdapter(t *testing.T) {
 	}
 	defer r.Close()
 
-	m, _, err := ggml.Decode(r, math.MaxInt)
+	m, _, err := ggml.Decode(r, -1)
 	if err != nil {
 		t.Fatal(err)
 	}

View File

@@ -11,14 +11,15 @@ type Tensor interface {
 	Name() string
 	Shape() []uint64
 	Kind() uint32
-	SetRepacker(repacker)
+	SetRepacker(Repacker)
 	WriteTo(io.Writer) (int64, error)
+	Clone() Tensor
 }
 
 type tensorBase struct {
 	name  string
 	shape []uint64
-	repacker
+	repacker Repacker
 }
 
 func (t tensorBase) Name() string {
@@ -36,7 +37,8 @@ const (
 
 func (t tensorBase) Kind() uint32 {
 	if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
-		t.name == "token_types.weight" {
+		t.name == "token_types.weight" ||
+		t.name == "v.positional_embedding_vlm" {
 		// these tensors are always F32
 		return 0
 	}
@@ -51,11 +53,11 @@ func (t tensorBase) Kind() uint32 {
 	}
 }
 
-func (t *tensorBase) SetRepacker(fn repacker) {
+func (t *tensorBase) SetRepacker(fn Repacker) {
 	t.repacker = fn
 }
 
-type repacker func(string, []float32, []uint64) ([]float32, error)
+type Repacker func(string, []float32, []uint64) ([]float32, error)
 
 func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
 	patterns := []struct {

View File

@@ -94,6 +94,21 @@ type safetensor struct {
 	*tensorBase
 }
 
+func (st safetensor) Clone() Tensor {
+	return &safetensor{
+		fs:     st.fs,
+		path:   st.path,
+		dtype:  st.dtype,
+		offset: st.offset,
+		size:   st.size,
+		tensorBase: &tensorBase{
+			name:     st.name,
+			repacker: st.repacker,
+			shape:    slices.Clone(st.shape),
+		},
+	}
+}
+
 func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 	f, err := st.fs.Open(st.path)
 	if err != nil {

View File

@@ -43,6 +43,17 @@ type torch struct {
 	*tensorBase
 }
 
+func (t torch) Clone() Tensor {
+	return torch{
+		storage: t.storage,
+		tensorBase: &tensorBase{
+			name:     t.name,
+			shape:    t.shape,
+			repacker: t.repacker,
+		},
+	}
+}
+
 func (pt torch) WriteTo(w io.Writer) (int64, error) {
 	return 0, nil
 }

View File

@@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
 
 ## How can I specify the context window size?
 
-By default, Ollama uses a context window size of 4096 tokens, unless you have a single GPU with <= 4 GB of VRAM, in which case it will default to 2048 tokens.
+By default, Ollama uses a context window size of 4096 tokens.
 
 This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
@@ -31,7 +31,7 @@ OLLAMA_CONTEXT_LENGTH=8192 ollama serve
 To change this when using `ollama run`, use `/set parameter`:
 
 ```shell
-/set parameter num_ctx 8192
+/set parameter num_ctx 4096
 ```
 
 When using the API, specify the `num_ctx` parameter:
@@ -41,7 +41,7 @@ curl http://localhost:11434/api/generate -d '{
   "model": "llama3.2",
   "prompt": "Why is the sky blue?",
   "options": {
-    "num_ctx": 8192
+    "num_ctx": 4096
   }
 }'
 ```

View File

@@ -169,7 +169,7 @@ var (
 	// Enable the new Ollama engine
 	NewEngine = Bool("OLLAMA_NEW_ENGINE")
 	// ContextLength sets the default context length
-	ContextLength = Int64("OLLAMA_CONTEXT_LENGTH", -1)
+	ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
 )
 
 func String(s string) func() string {
@@ -227,20 +227,6 @@ func Uint64(key string, defaultValue uint64) func() uint64 {
 	}
 }
 
-func Int64(key string, defaultValue int64) func() int64 {
-	return func() int64 {
-		if s := Var(key); s != "" {
-			if n, err := strconv.ParseInt(s, 10, 64); err != nil {
-				slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
-			} else {
-				return n
-			}
-		}
-		return defaultValue
-	}
-}
-
 // Set aside VRAM per GPU
 var GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0)
 
@@ -269,7 +255,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_ORIGINS":         {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
 		"OLLAMA_SCHED_SPREAD":    {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
 		"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
-		"OLLAMA_CONTEXT_LENGTH":  {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default 4096 or 2048 with low VRAM)"},
+		"OLLAMA_CONTEXT_LENGTH":  {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
 		"OLLAMA_NEW_ENGINE":      {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
 
 		// Informational

View File

@@ -278,9 +278,9 @@ func TestVar(t *testing.T) {
 }
 
 func TestContextLength(t *testing.T) {
-	cases := map[string]int64{
-		"":     -1,
-		"4096": 4096,
+	cases := map[string]uint{
+		"":     4096,
+		"2048": 2048,
 	}
 
 	for k, v := range cases {

View File

@@ -8,6 +8,6 @@ type Config interface {
 	Bool(string, ...bool) bool
 
 	Strings(string, ...[]string) []string
-	Uints(string, ...[]uint32) []uint32
+	Ints(string, ...[]int32) []int32
 	Floats(string, ...[]float32) []float32
 }

View File

@@ -33,7 +33,7 @@ func (kv KV) Kind() string {
 }
 
 func (kv KV) ParameterCount() uint64 {
-	return keyValue[uint64](kv, "general.parameter_count")
+	return keyValue(kv, "general.parameter_count", uint64(0))
 }
 
 func (kv KV) FileType() fileType {
@@ -105,42 +105,42 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
 }
 
 func (kv KV) Strings(key string, defaultValue ...[]string) []string {
-	r := keyValue(kv, key, &array{})
-	s := make([]string, r.size)
-	for i := range r.size {
-		s[i] = r.values[i].(string)
-	}
-
-	return s
+	return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
+}
+
+func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
+	return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
 }
 
 func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
-	r := keyValue(kv, key, &array{})
-	s := make([]uint32, r.size)
-	for i := range r.size {
-		s[i] = uint32(r.values[i].(int32))
-	}
-	return s
+	return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
}
 
 func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
-	r := keyValue(kv, key, &array{})
-	s := make([]float32, r.size)
-	for i := range r.size {
-		s[i] = float32(r.values[i].(float32))
-	}
-	return s
+	return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
 }
 
 func (kv KV) OllamaEngineRequired() bool {
 	return slices.Contains([]string{
 		"gemma3",
 		"mistral3",
+		"llama4",
 	}, kv.Architecture())
 }
 
-func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
+type valueTypes interface {
+	uint8 | int8 | uint16 | int16 |
+		uint32 | int32 | uint64 | int64 |
+		string | float32 | float64 | bool
+}
+
+type arrayValueTypes interface {
+	*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
+		*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
+		*array[string] | *array[float32] | *array[float64] | *array[bool]
+}
+
+func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
 	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
 		key = kv.Architecture() + "." + key
 	}
@@ -375,13 +375,8 @@ func DetectContentType(b []byte) string {
 
 // Decode decodes a GGML model from the given reader.
 //
 // It collects array values for arrays with a size less than or equal to
-// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
-// the maxArraySize is negative, all arrays are collected.
+// maxArraySize. If the maxArraySize is negative, all arrays are collected.
 func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
-	if maxArraySize == 0 {
-		maxArraySize = 1024
-	}
-
 	rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
 
 	var magic uint32
@@ -420,7 +415,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 	embedding := f.KV().EmbeddingLength()
 	heads := f.KV().HeadCount()
 	headsKV := f.KV().HeadCountKV()
-	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
+	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
 
 	embeddingHeads := f.KV().EmbeddingHeadCount()
 	embeddingHeadsK := f.KV().EmbeddingHeadCountK()
@@ -435,7 +430,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 	}
 
 	switch f.KV().Architecture() {
-	case "llama":
+	case "llama", "llama4":
 		fullOffload = max(
 			4*batch*(1+4*embedding+context*(1+heads)),
 			4*batch*(embedding+vocab),
@@ -449,7 +444,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 
 		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
 			// mixtral 8x22b
-			ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
+			ff := uint64(f.KV().Uint("feed_forward_length"))
 			partialOffload = max(
 				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
 				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@@ -466,9 +461,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 	case "mllama":
 		var visionTokens, tiles uint64 = 1601, 4
 
-		crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
+		crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers")
 		for i := range kv {
-			if slices.Contains(crossAttentionLayers, uint32(i)) {
+			if slices.Contains(crossAttentionLayers, int32(i)) {
 				kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
 					4 * // sizeof(float32)
 					visionTokens *
@@ -645,6 +640,9 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
 		graphSize = 4 * (imageSize*imageSize*numChannels +
 			embeddingLength*patchSize +
 			numPatches*numPatches*headCount)
+	case "llama4":
+		// vision graph is computed independently in the same schedule
+		// and is negligible compared to the worst case text graph
 	}
 
 	return weights, graphSize

View File

@@ -2,6 +2,7 @@ package ggml
 
 import (
 	"maps"
+	"math"
 	"slices"
 	"strconv"
 	"strings"
@@ -210,3 +211,61 @@
 		})
 	}
 }

+func TestKeyValue(t *testing.T) {
+	kv := KV{
+		"general.architecture": "test",
+		"test.strings":         &array[string]{size: 3, values: []string{"a", "b", "c"}},
+		"test.float32s":        &array[float32]{size: 3, values: []float32{1.0, 2.0, 3.0}},
+		"test.int32s":          &array[int32]{size: 3, values: []int32{1, 2, 3}},
+		"test.uint32s":         &array[uint32]{size: 3, values: []uint32{1, 2, 3}},
+	}

+	if diff := cmp.Diff(kv.Strings("strings"), []string{"a", "b", "c"}); diff != "" {
+		t.Errorf("unexpected strings (-got +want):\n%s", diff)
+	}
+	if diff := cmp.Diff(kv.Strings("nonexistent.strings"), []string(nil)); diff != "" {
+		t.Errorf("unexpected strings (-got +want):\n%s", diff)
+	}
+	if diff := cmp.Diff(kv.Strings("default.strings", []string{"ollama"}), []string{"ollama"}); diff != "" {
+		t.Errorf("unexpected strings (-got +want):\n%s", diff)
+	}

+	if diff := cmp.Diff(kv.Floats("float32s"), []float32{1.0, 2.0, 3.0}); diff != "" {
+		t.Errorf("unexpected float32s (-got +want):\n%s", diff)
+	}
+	if diff := cmp.Diff(kv.Floats("nonexistent.float32s"), []float32(nil)); diff != "" {
+		t.Errorf("unexpected float32s (-got +want):\n%s", diff)
+	}
+	if diff := cmp.Diff(kv.Floats("default.float32s", []float32{math.MaxFloat32}), []float32{math.MaxFloat32}); diff != "" {
+		t.Errorf("unexpected float32s (-got +want):\n%s", diff)
+	}

+	if diff := cmp.Diff(kv.Ints("int32s"), []int32{1, 2, 3}); diff != "" {
+		t.Errorf("unexpected int32s (-got +want):\n%s", diff)
+	}
+	if diff := cmp.Diff(kv.Ints("nonexistent.int32s"), []int32(nil)); diff != "" {
+		t.Errorf("unexpected int32s (-got +want):\n%s", diff)
+	}
+	if diff := cmp.Diff(kv.Ints("default.int32s", []int32{math.MaxInt32}), []int32{math.MaxInt32}); diff != "" {
+		t.Errorf("unexpected int32s (-got +want):\n%s", diff)
+	}

+	if diff := cmp.Diff(kv.Uints("uint32s"), []uint32{1, 2, 3}); diff != "" {
+		t.Errorf("unexpected uint32s (-got +want):\n%s", diff)
+	}
+	if diff := cmp.Diff(kv.Uints("nonexistent.uint32s"), []uint32(nil)); diff != "" {
+		t.Errorf("unexpected uint32s (-got +want):\n%s", diff)
+	}
+	if diff := cmp.Diff(kv.Uints("default.uint32s", []uint32{math.MaxUint32}), []uint32{math.MaxUint32}); diff != "" {
+		t.Errorf("unexpected uint32s (-got +want):\n%s", diff)
+	}
+}
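For reference, the accessors exercised above (Strings, Floats, Ints, Uints) all share one shape: look the key up under the architecture prefix, type-assert, and fall back to an optional default. A minimal generic sketch of that pattern; keyValue is a hypothetical helper, not the actual ollama implementation:

	func keyValue[T any](kv map[string]any, key string, defaultValue ...T) T {
		if v, ok := kv[key].(T); ok {
			return v
		}
		if len(defaultValue) > 0 {
			return defaultValue[0]
		}
		var zero T // missing key and no default: zero value (nil for slices)
		return zero
	}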


@@ -9,12 +9,8 @@ import (
	"io"
	"log/slog"
	"maps"
-	"os"
-	"runtime"
	"slices"
	"strings"
-
-	"golang.org/x/sync/errgroup"
)

type containerGGUF struct {
@@ -40,10 +36,6 @@ type containerGGUF struct {
	maxArraySize int
}

-func (c *containerGGUF) canCollectArray(size int) bool {
-	return c.maxArraySize < 0 || size <= c.maxArraySize
-}

func (c *containerGGUF) Name() string {
	return "gguf"
}
@@ -299,6 +291,23 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
	return b.String(), nil
}

+func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
+	for i := range a.size {
+		if a.values != nil {
+			e, err := readGGUFV1String(llm, r)
+			if err != nil {
+				return nil, err
+			}
+			a.values[i] = e
+		} else {
+			discardGGUFString(llm, r)
+		}
+	}
+	return a, nil
+}

func discardGGUFString(llm *gguf, r io.Reader) error {
	buf := llm.scratch[:8]
	_, err := io.ReadFull(r, buf)
@@ -356,78 +365,44 @@ func writeGGUFString(w io.Writer, s string) error {
	return err
}

-type array struct {
-	size   int
-	values []any
-}
-
-func (a *array) MarshalJSON() ([]byte, error) {
-	return json.Marshal(a.values)
-}
-
-func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
-	t, err := readGGUF[uint32](llm, r)
-	if err != nil {
-		return nil, err
-	}
-
-	n, err := readGGUF[uint32](llm, r)
-	if err != nil {
-		return nil, err
-	}
-
-	a := &array{size: int(n)}
-	if llm.canCollectArray(int(n)) {
-		a.values = make([]any, 0, int(n))
-	}
-
-	for i := range n {
-		var e any
-		switch t {
-		case ggufTypeUint8:
-			e, err = readGGUF[uint8](llm, r)
-		case ggufTypeInt8:
-			e, err = readGGUF[int8](llm, r)
-		case ggufTypeUint16:
-			e, err = readGGUF[uint16](llm, r)
-		case ggufTypeInt16:
-			e, err = readGGUF[int16](llm, r)
-		case ggufTypeUint32:
-			e, err = readGGUF[uint32](llm, r)
-		case ggufTypeInt32:
-			e, err = readGGUF[int32](llm, r)
-		case ggufTypeUint64:
-			e, err = readGGUF[uint64](llm, r)
-		case ggufTypeInt64:
-			e, err = readGGUF[int64](llm, r)
-		case ggufTypeFloat32:
-			e, err = readGGUF[float32](llm, r)
-		case ggufTypeFloat64:
-			e, err = readGGUF[float64](llm, r)
-		case ggufTypeBool:
-			e, err = readGGUF[bool](llm, r)
-		case ggufTypeString:
-			e, err = readGGUFV1String(llm, r)
-		default:
-			return nil, fmt.Errorf("invalid array type: %d", t)
-		}
-		if err != nil {
-			return nil, err
-		}
-
-		if a.values != nil {
-			a.values[i] = e
-		}
-	}
-
-	return a, nil
-}
-
-func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
-	if llm.Version == 1 {
-		return readGGUFV1Array(llm, r)
-	}
-
+func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
+	for i := range a.size {
+		if a.values != nil {
+			e, err := readGGUFString(llm, r)
+			if err != nil {
+				return nil, err
+			}
+			a.values[i] = e
+		} else {
+			discardGGUFString(llm, r)
+		}
+	}
+	return a, nil
+}
+
+type array[T any] struct {
+	// size is the actual size of the array
+	size int
+
+	// values is the array of values. this is nil if the array is larger than configured maxSize
+	values []T
+}
+
+func (a *array[T]) MarshalJSON() ([]byte, error) {
+	return json.Marshal(a.values)
+}
+
+func newArray[T any](size, maxSize int) *array[T] {
+	a := array[T]{size: size}
+	if maxSize < 0 || size <= maxSize {
+		a.values = make([]T, size)
+	}
+	return &a
+}
+
+func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
	t, err := readGGUF[uint32](llm, r)
	if err != nil {
		return nil, err
	}
@@ -438,45 +413,55 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
		return nil, err
	}

-	a := &array{size: int(n)}
-	if llm.canCollectArray(int(n)) {
-		a.values = make([]any, int(n))
-	}
-
-	for i := range n {
-		var e any
-		switch t {
-		case ggufTypeUint8:
-			e, err = readGGUF[uint8](llm, r)
-		case ggufTypeInt8:
-			e, err = readGGUF[int8](llm, r)
-		case ggufTypeUint16:
-			e, err = readGGUF[uint16](llm, r)
-		case ggufTypeInt16:
-			e, err = readGGUF[int16](llm, r)
-		case ggufTypeUint32:
-			e, err = readGGUF[uint32](llm, r)
-		case ggufTypeInt32:
-			e, err = readGGUF[int32](llm, r)
-		case ggufTypeUint64:
-			e, err = readGGUF[uint64](llm, r)
-		case ggufTypeInt64:
-			e, err = readGGUF[int64](llm, r)
-		case ggufTypeFloat32:
-			e, err = readGGUF[float32](llm, r)
-		case ggufTypeFloat64:
-			e, err = readGGUF[float64](llm, r)
-		case ggufTypeBool:
-			e, err = readGGUF[bool](llm, r)
-		case ggufTypeString:
-			if a.values != nil {
-				e, err = readGGUFString(llm, r)
-			} else {
-				err = discardGGUFString(llm, r)
-			}
-		default:
-			return nil, fmt.Errorf("invalid array type: %d", t)
-		}
+	switch t {
+	case ggufTypeUint8:
+		a := newArray[uint8](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeInt8:
+		a := newArray[int8](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeUint16:
+		a := newArray[uint16](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeInt16:
+		a := newArray[int16](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeUint32:
+		a := newArray[uint32](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeInt32:
+		a := newArray[int32](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeUint64:
+		a := newArray[uint64](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeInt64:
+		a := newArray[int64](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeFloat32:
+		a := newArray[float32](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeFloat64:
+		a := newArray[float64](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeBool:
+		a := newArray[bool](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeString:
+		a := newArray[string](int(n), llm.maxArraySize)
+		if llm.Version == 1 {
+			return readGGUFV1StringsData(llm, r, a)
+		}
+		return readGGUFStringsData(llm, r, a)
+	default:
+		return nil, fmt.Errorf("invalid array type: %d", t)
+	}
+}
+
+func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) {
+	for i := range a.size {
+		e, err := readGGUF[T](llm, r)
		if err != nil {
			return nil, err
		}
@@ -506,22 +491,22 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
	return binary.Write(w, binary.LittleEndian, s)
}

-func WriteGGUF(f *os.File, kv KV, ts []Tensor) error {
+func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
	alignment := kv.Uint("general.alignment", 32)

-	if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
+	if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
		return err
	}

-	if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
+	if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
		return err
	}

-	if err := binary.Write(f, binary.LittleEndian, uint64(len(ts))); err != nil {
+	if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
		return err
	}

-	if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
+	if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
		return err
	}
@@ -529,7 +514,7 @@ func WriteGGUF(f *os.File, kv KV, ts []Tensor) error {
	slices.Sort(keys)

	for _, key := range keys {
-		if err := ggufWriteKV(f, key, kv[key]); err != nil {
+		if err := ggufWriteKV(ws, key, kv[key]); err != nil {
			return err
		}
	}
@@ -545,34 +530,21 @@ func WriteGGUF(f *os.File, kv KV, ts []Tensor) error {
	})

	var s uint64
-	for i := range ts {
-		ts[i].Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
-		if err := ggufWriteTensorInfo(f, ts[i]); err != nil {
+	for _, t := range ts {
+		t.Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
+		if err := ggufWriteTensorInfo(ws, t); err != nil {
			return err
		}
-		s += ts[i].Size()
+		s += t.Size()
	}

-	offset, err := f.Seek(0, io.SeekCurrent)
-	if err != nil {
-		return err
-	}
-	offset += ggufPadding(offset, int64(alignment))
-	slog.Debug("gguf", "offset", offset, "size", s, "alignment", alignment)
-
-	var g errgroup.Group
-	g.SetLimit(runtime.GOMAXPROCS(0))
	for _, t := range ts {
-		t := t
-		w := io.NewOffsetWriter(f, offset+int64(t.Offset))
-		g.Go(func() error {
-			_, err := t.WriteTo(w)
-			return err
-		})
+		if err := ggufWriteTensor(ws, t, int64(alignment)); err != nil {
+			return err
+		}
	}

-	return g.Wait()
+	return nil
}

func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
@@ -644,8 +616,8 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
		return err
	}

-	for i := range len(t.Shape) {
-		if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
+	for _, n := range t.Shape {
+		if err := binary.Write(ws, binary.LittleEndian, n); err != nil {
			return err
		}
	}
@@ -657,6 +629,20 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
	return binary.Write(ws, binary.LittleEndian, t.Offset)
}

+func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
+	offset, err := ws.Seek(0, io.SeekCurrent)
+	if err != nil {
+		return err
+	}
+
+	if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
+		return err
+	}
+
+	_, err = t.WriteTo(ws)
+	return err
+}

func ggufPadding(offset, align int64) int64 {
	return (align - offset%align) % align
}
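The generic array[T] introduced above records the declared size even when it declines to materialize the values. A few illustrative calls showing the capping rule in newArray:

	a := newArray[uint32](100000, 1024) // a.size == 100000, a.values == nil: over the cap
	b := newArray[uint32](512, 1024)    // b.values allocated with len 512
	c := newArray[string](100000, -1)   // negative maxSize collects everything

When values is nil the readers still consume the element data from the stream (the string paths discard via discardGGUFString), so the decoder stays correctly positioned for the next key.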

go.mod

@@ -11,7 +11,7 @@ require (
	github.com/spf13/cobra v1.7.0
	github.com/stretchr/testify v1.9.0
	github.com/x448/float16 v0.8.4
-	golang.org/x/sync v0.11.0
+	golang.org/x/sync v0.12.0
)

require (
@@ -70,12 +70,12 @@ require (
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
	github.com/ugorji/go/codec v1.2.12 // indirect
	golang.org/x/arch v0.8.0 // indirect
-	golang.org/x/crypto v0.33.0
+	golang.org/x/crypto v0.36.0
	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
-	golang.org/x/net v0.35.0 // indirect
-	golang.org/x/sys v0.30.0
-	golang.org/x/term v0.29.0
-	golang.org/x/text v0.22.0
+	golang.org/x/net v0.38.0 // indirect
+	golang.org/x/sys v0.31.0
+	golang.org/x/term v0.30.0
+	golang.org/x/text v0.23.0
	google.golang.org/protobuf v1.34.1
	gopkg.in/yaml.v3 v3.0.1 // indirect
)

go.sum

@@ -214,8 +214,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
-golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
+golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
-golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
+golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
+golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
-golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
+golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
-golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
+golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
-golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
+golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
+golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
-golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
+golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
+golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=


@@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
func TestAllMiniLMEmbeddings(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()

	req := api.EmbeddingRequest{
		Model:  "all-minilm",
		Prompt: "why is the sky blue?",
	}

-	res, err := embeddingTestHelper(ctx, t, req)
+	res, err := embeddingTestHelper(ctx, client, t, req)

	if err != nil {
		t.Fatalf("error: %v", err)
@@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
func TestAllMiniLMEmbed(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()

	req := api.EmbedRequest{
		Model: "all-minilm",
		Input: "why is the sky blue?",
	}

-	res, err := embedTestHelper(ctx, t, req)
+	res, err := embedTestHelper(ctx, client, t, req)

	if err != nil {
		t.Fatalf("error: %v", err)
@@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) {
func TestAllMiniLMBatchEmbed(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()

	req := api.EmbedRequest{
		Model: "all-minilm",
		Input: []string{"why is the sky blue?", "why is the grass green?"},
	}

-	res, err := embedTestHelper(ctx, t, req)
+	res, err := embedTestHelper(ctx, client, t, req)

	if err != nil {
		t.Fatalf("error: %v", err)
@@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
func TestAllMiniLMEmbedTruncate(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()

	truncTrue, truncFalse := true, false
@@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
	res := make(map[string]*api.EmbedResponse)

	for _, req := range reqs {
-		response, err := embedTestHelper(ctx, t, req.Request)
+		response, err := embedTestHelper(ctx, client, t, req.Request)

		if err != nil {
			t.Fatalf("error: %v", err)
		}
@@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
	}

	// check that truncate set to false returns an error if context length is exceeded
-	_, err := embedTestHelper(ctx, t, api.EmbedRequest{
+	_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
		Model:    "all-minilm",
		Input:    "why is the sky blue?",
		Truncate: &truncFalse,
@@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
	}
}

-func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
+func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
	if err := PullIfMissing(ctx, client, req.Model); err != nil {
		t.Fatalf("failed to pull model %s: %v", req.Model, err)
	}
@@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq
	return response, nil
}

-func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
+func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
	if err := PullIfMissing(ctx, client, req.Model); err != nil {
		t.Fatalf("failed to pull model %s: %v", req.Model, err)
	}


@@ -21,6 +21,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type Causal struct {
	DType      ml.DType
	windowSize int32
+	chunkSize  int32

	opts CausalOptions
@@ -97,6 +98,17 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
	}
}

+func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
+	return &Causal{
+		windowSize: math.MaxInt32,
+		chunkSize:  chunkSize,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
+}

func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	if c.config == nil {
		var config ml.CacheConfig
@@ -300,6 +312,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
				(enabled && c.cells[j].pos > c.curPositions[i]) ||
+				c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
				c.cells[j].pos < c.curPositions[i]-c.windowSize {
				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
			}
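The added chunkSize term restricts each query to its own chunk: a token at position p may only attend to cached positions in [p - p%chunkSize, p]. A standalone restatement of the condition (hypothetical helper; the sequence-membership and sliding-window checks are omitted here):

	// masked reports whether cached position j is hidden from a query at position p.
	func masked(j, p, chunkSize int32) bool {
		return j > p || // causal: future positions are always hidden
			(chunkSize > 0 && j < p-p%chunkSize) // chunked: nothing before the chunk start
	}

With chunkSize = 2 this reproduces the masks asserted in the test below, e.g. position 3 sees only positions 2 and 3.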


@@ -86,6 +86,64 @@ func TestSWA(t *testing.T) {
	testCache(t, backend, cache, tests)
}

+func TestChunkedAttention(t *testing.T) {
+	cache := NewChunkedAttentionCache(2, nil)
+	defer cache.Close()
+
+	var b testBackend
+	cache.Init(&b, ml.DTypeF16, 1, 16, 16)
+
+	x := float32(math.Inf(-1))
+
+	testCache(
+		t, &b, cache,
+		[]testCase{
+			{
+				name:          "FirstBatch",
+				in:            []float32{1, 2, 3, 4},
+				inShape:       []int{1, 1, 4},
+				seqs:          []int{0, 0, 0, 0},
+				pos:           []int32{0, 1, 2, 3},
+				expected:      []float32{1, 2, 3, 4},
+				expectedShape: []int{1, 1, 4},
+				expectedMask: []float32{
+					0, x, x, x,
+					0, 0, x, x,
+					x, x, 0, x,
+					x, x, 0, 0,
+				},
+			},
+			{
+				name:          "SecondBatch",
+				in:            []float32{5, 6, 7},
+				inShape:       []int{1, 1, 3},
+				seqs:          []int{0, 0, 0},
+				pos:           []int32{4, 5, 6},
+				expected:      []float32{1, 2, 3, 4, 5, 6, 7},
+				expectedShape: []int{1, 1, 7},
+				expectedMask: []float32{
+					x, x, x, x, 0, x, x,
+					x, x, x, x, 0, 0, x,
+					x, x, x, x, x, x, 0,
+				},
+			},
+			{
+				name:          "ThirdBatch",
+				in:            []float32{8, 9},
+				inShape:       []int{1, 1, 2},
+				seqs:          []int{0, 0},
+				pos:           []int32{7, 8},
+				expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
+				expectedShape: []int{1, 1, 9},
+				expectedMask: []float32{
+					x, x, x, x, x, x, 0, 0, x,
+					x, x, x, x, x, x, x, x, 0,
+				},
+			},
+		},
+	)
+}

func TestSequences(t *testing.T) {
	backend := &testBackend{}
	cache := NewCausalCache(nil)
@@ -293,8 +351,16 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
		context.Forward(out, mask).Compute(out, mask)

-		if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
-			t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
+		if !slices.Equal(out.Floats(), test.expected) {
+			t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
+		}
+
+		if !slices.Equal(out.Shape(), test.expectedShape) {
+			t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
+		}
+
+		if !slices.Equal(mask.Floats(), test.expectedMask) {
+			t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
		}
	})
}


@@ -414,7 +414,7 @@ func projectorMemoryRequirements(filename string) (weights, graphSize uint64) {
	}
	defer file.Close()

-	ggml, _, err := ggml.Decode(file, 0)
+	ggml, _, err := ggml.Decode(file, 1024)
	if err != nil {
		return 0, 0
	}
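Decode's second argument is the maxArraySize threaded through to the GGUF array reader above: with 1024 instead of 0, small metadata arrays are collected while oversized ones (token tables and the like) only record their length. An illustrative call:

	file, _ := os.Open(projectorPath) // projectorPath is a placeholder
	defer file.Close()
	meta, _, err := ggml.Decode(file, 1024) // collect arrays of up to 1024 elements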


@@ -329,11 +329,13 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
		libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
	}

+	ggmlPaths := []string{discover.LibOllamaPath}
	if len(compatible) > 0 {
		c := compatible[0]
		if libpath, ok := libs[c]; ok {
			slog.Debug("adding gpu library", "path", libpath)
			libraryPaths = append(libraryPaths, libpath)
+			ggmlPaths = append(ggmlPaths, libpath)
		}
	}
@@ -369,6 +371,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
	s.cmd.Stderr = s.status
	s.cmd.SysProcAttr = LlamaServerSysProcAttr

+	s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))

	envWorkarounds := [][2]string{}
	for _, gpu := range gpus {
		envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
@@ -406,7 +410,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
	if envconfig.Debug() {
		filteredEnv := []string{}
		for _, ev := range s.cmd.Env {
-			if strings.HasPrefix(ev, "CUDA_") ||
+			if strings.HasPrefix(ev, "OLLAMA_") ||
+				strings.HasPrefix(ev, "CUDA_") ||
				strings.HasPrefix(ev, "ROCR_") ||
				strings.HasPrefix(ev, "ROCM_") ||
				strings.HasPrefix(ev, "HIP_") ||


@@ -133,6 +133,7 @@ type Tensor interface {
	Mul(ctx Context, t2 Tensor) Tensor
	Mulmat(ctx Context, t2 Tensor) Tensor
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
+	MulmatID(ctx Context, t2, ids Tensor) Tensor

	Softmax(ctx Context) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
@@ -150,6 +151,7 @@ type Tensor interface {
	Tanh(ctx Context) Tensor
	GELU(ctx Context) Tensor
	SILU(ctx Context) Tensor
+	Sigmoid(ctx Context) Tensor

	Reshape(ctx Context, shape ...int) Tensor
	View(ctx Context, offset int, shape ...int) Tensor
@@ -168,6 +170,8 @@ type Tensor interface {
	Rows(ctx Context, t2 Tensor) Tensor
	Copy(ctx Context, t2 Tensor) Tensor
	Duplicate(ctx Context) Tensor
+
+	TopK(ctx Context, k int) Tensor
}

// ScaledDotProductAttention implements a fused attention


@@ -884,17 +884,32 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	}
}

+func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
+	}
+}

func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
-	tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
-	if b != nil {
-		tt = tt.Add(ctx, b)
+	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
+	if w != nil {
+		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
+		if b != nil {
+			tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t)
+		}
	}

-	return tt
+	return &Tensor{b: t.b, t: tt}
}

func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
-	return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+	tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))
+	if w != nil {
+		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
+	}
+
+	return &Tensor{b: t.b, t: tt}
}

func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -995,6 +1010,13 @@ func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
	}
}

+func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t),
+	}
+}

func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
	if len(shape) != 4 {
		panic("expected 4 dimensions")
@@ -1158,3 +1180,10 @@ func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
		t: C.ggml_dup(ctx.(*Context).ctx, t.t),
	}
}

+func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
+	}
+}
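MulmatID, Sigmoid, and TopK are the routing primitives for mixture-of-experts layers. A speculative sketch of how they compose (shapes and reshapes elided; the real call sites are in TextExperts.Forward in the llama4 model further down):

	experts := routerLogits.TopK(ctx, numExpertsUsed)         // expert indices per token
	scores := routerLogits.Sigmoid(ctx).Rows(ctx, experts)    // gate scores for those experts
	out := expertWeights.MulmatID(ctx, hiddenStates, experts) // matmul against selected experts only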


@@ -57,26 +57,20 @@ var OnceLoad = sync.OnceFunc(func() {
		exe = "."
	}

-	// PATH, LD_LIBRARY_PATH, and DYLD_LIBRARY_PATH are often
-	// set by the parent process, however, use a default value
-	// if the environment variable is not set.
-	var name, value string
+	var value string
	switch runtime.GOOS {
	case "darwin":
-		// On macOS, DYLD_LIBRARY_PATH is often not set, so
-		// we use the directory of the executable as the default.
-		name = "DYLD_LIBRARY_PATH"
		value = filepath.Dir(exe)
	case "windows":
-		name = "PATH"
		value = filepath.Join(filepath.Dir(exe), "lib", "ollama")
	default:
-		name = "LD_LIBRARY_PATH"
		value = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
	}

-	paths, ok := os.LookupEnv(name)
+	// Avoid potentially loading incompatible GGML libraries
+	paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
	if !ok {
+		slog.Debug("OLLAMA_LIBRARY_PATH not set, falling back to default", "search", value)
		paths = value
	}


@@ -42,7 +42,7 @@ func New(c fs.Config) (model.Model, error) {
		&model.Vocabulary{
			Values: c.Strings("tokenizer.ggml.tokens"),
			Scores: c.Floats("tokenizer.ggml.scores"),
-			Types:  c.Uints("tokenizer.ggml.token_type"),
+			Types:  c.Ints("tokenizer.ggml.token_type"),
			BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
			EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
		},


@@ -59,7 +59,7 @@ func New(c fs.Config) (model.Model, error) {
		&model.Vocabulary{
			Values: c.Strings("tokenizer.ggml.tokens"),
			Scores: c.Floats("tokenizer.ggml.scores"),
-			Types:  c.Uints("tokenizer.ggml.token_type"),
+			Types:  c.Ints("tokenizer.ggml.token_type"),
			BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
			AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
			EOS:    int32(1),


@@ -49,7 +49,7 @@ func newTextModel(c fs.Config) *TextModel {
		&model.Vocabulary{
			Values: c.Strings("tokenizer.ggml.tokens"),
			Scores: c.Floats("tokenizer.ggml.scores"),
-			Types:  c.Uints("tokenizer.ggml.token_type"),
+			Types:  c.Ints("tokenizer.ggml.token_type"),
			BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
			EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
		},


@@ -41,7 +41,7 @@ func New(c fs.Config) (model.Model, error) {
		c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
		&model.Vocabulary{
			Values: c.Strings("tokenizer.ggml.tokens"),
-			Types:  c.Uints("tokenizer.ggml.token_type"),
+			Types:  c.Ints("tokenizer.ggml.token_type"),
			Merges: c.Strings("tokenizer.ggml.merges"),
			BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
			AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),


@@ -0,0 +1,189 @@
package llama4
import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.BytePairEncoding
ImageProcessor
*VisionModel `gguf:"v,vision"`
*Projector `gguf:"mm"`
*TextModel
}
type Projector struct {
Linear1 *nn.Linear `gguf:"linear_1"`
}
func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
return p.Linear1.Forward(ctx, visionOutputs)
}
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewChunkedAttentionCache(int32(c.Uint("attention.chunk_size", 8192)), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) < 1 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
pixelsLocal, pixelsGlobal, size, err := m.ProcessImage(img)
if err != nil {
return nil, err
}
tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
if err != nil {
return nil, err
}
ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, ratioW, size.Y, m.numChannels).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW*size.Y/ratioH, ratioH, ratioW, m.numChannels).Permute(ctx, 0, 3, 2, 1).Contiguous(ctx)
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, size.Y/ratioH, m.numChannels, ratioH*ratioW)
pixelValues := tilesLocal
if len(pixelsGlobal) > 0 {
tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
if err != nil {
return nil, err
}
pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3)
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
}
type chunks struct {
*Model
ml.Tensor
aspectRatio image.Point
dataOnce sync.Once
data []float32
}
type chunk struct {
*chunks
s, n int
}
func (r *chunk) floats() []float32 {
r.dataOnce.Do(func() {
temp := r.Backend().NewContext()
defer temp.Close()
temp.Forward(r.Tensor).Compute(r.Tensor)
r.data = r.Floats()
})
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
continue
}
t := inp.Multimodal.(*chunks)
var imageInputs []input.Input
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
var offset int
patchesPerChunk := t.Dim(1)
if t.aspectRatio.Y*t.aspectRatio.X > 1 {
patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
for range t.aspectRatio.Y {
for x := range t.aspectRatio.X {
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if x < t.aspectRatio.X-1 {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
}
offset += patchesPerChunk
}
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
}
}
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200081}) // <|image_end|>
result = append(result, imageInputs...)
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("llama4", New)
}
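To make the PostTokenize layout concrete: for an image split into a 2x2 grid with p = patchesPerChunk, the emitted token stream is

	<|image_start|>
	p x <|patch|>  <|tile_x_separator|>  p x <|patch|>  <|tile_y_separator|>
	p x <|patch|>  <|tile_x_separator|>  p x <|patch|>  <|tile_y_separator|>
	<|image|>  p x <|patch|>  <|image_end|>

where only the first <|patch|> of each run carries the multimodal chunk (via Multimodal and SameBatch) and the remaining p-1 tokens are placeholders.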


@@ -0,0 +1,259 @@
package llama4
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
)
type TextAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_factors"`
}
func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope {
query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
}
if opts.useQKNorm {
query = query.RMSNorm(ctx, nil, opts.eps)
key = key.RMSNorm(ctx, nil, opts.eps)
}
if attentionScales != nil && !useRope {
query = query.Mul(ctx, attentionScales)
}
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextExperts struct {
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
}
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
experts := routerLogits.TopK(ctx, opts.numExpertsUsed)
scores := routerLogits.Sigmoid(ctx).Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, experts)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores)
upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
}
return nextStates
}
// TextSharedExpert is TextMLP with different tensor names
type TextSharedExpert struct {
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
Up *nn.Linear `gguf:"ffn_up_shexp"`
Down *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextMOE struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Experts *TextExperts
SharedExpert *TextSharedExpert
}
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
routerLogits := moe.Router.Forward(ctx, hiddenStates)
sharedStates := moe.SharedExpert.Forward(ctx, hiddenStates, opts)
routedStates := moe.Experts.Forward(ctx, hiddenStates, routerLogits, opts)
return sharedStates.Add(ctx, routedStates)
}
type TextFeedForward interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor
}
type TextLayer struct {
AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
Attention *TextAttention
FFNNorm *nn.LayerNorm `gguf:"ffn_norm"`
FeedForward TextFeedForward
}
func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
residual := hiddenStates
// self attention
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, useRope, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = d.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.FeedForward.Forward(ctx, hiddenStates, opts)
return residual.Add(ctx, hiddenStates)
}
type TextOptions struct {
hiddenSize int
numHeads, numKVHeads, headDim int
numExperts, numExpertsUsed int
ropeDim int
ropeBase, ropeScale float32
eps float32
interleaveLayerStep int
noRopeInterval int
useQKNorm bool
attentionTemperatureTuning bool
attentionScale float64
attentionFloorScale float64
}
type TextModel struct {
Layers []TextLayer `gguf:"blk"`
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.LayerNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
func newTextModel(c fs.Config) *TextModel {
layers := make([]TextLayer, c.Uint("block_count"))
interleaveLayerStep := c.Uint("interleave_moe_layer_step", 1)
for i := range layers {
if (i+1)%int(interleaveLayerStep) == 0 {
layers[i] = TextLayer{FeedForward: &TextMOE{}}
} else {
layers[i] = TextLayer{FeedForward: &TextMLP{}}
}
}
return &TextModel{
Layers: layers,
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.head_dim", 128)),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"),
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
noRopeInterval: int(c.Uint("no_rope_interval", 4)),
useQKNorm: c.Bool("use_qk_norm", true),
attentionTemperatureTuning: c.Bool("attention.temperature_tuning", true),
attentionScale: float64(c.Float("attention.scale", 0.1)),
attentionFloorScale: float64(c.Float("attention.floor_scale", 8192)),
},
}
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, mi := range batch.Multimodal {
f32s := mi.Multimodal.(*chunk).floats()
img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
if err != nil {
panic(err)
}
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
}
var attentionScales ml.Tensor
if m.attentionTemperatureTuning {
scales := make([]float32, len(batch.Positions))
for i, p := range batch.Positions {
scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0)
}
var err error
attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
if err != nil {
panic(err)
}
}
for i, layer := range m.Layers {
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(1)
useChunkedAttention := (i+1)%m.noRopeInterval != 0
if useChunkedAttention {
wc.SetLayerType(0)
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(m.ropeDim), uint32(0), m.ropeBase, m.ropeScale), nil
}
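With the default noRopeInterval of 4, the loop in Forward gives every fourth layer (i = 3, 7, 11, ...) global causal attention without RoPE, while the other layers use RoPE with the chunked cache. A quick illustration of the selection rule:

	for i := range 8 {
		fmt.Println(i, (i+1)%4 != 0) // true: chunked + RoPE; false for i = 3 and i = 7
	}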


@@ -0,0 +1,256 @@
package llama4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type VisionAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor.
// This is equivalent to the PyTorch implementation using half rotations:
//
// cos, sin = torch.cos(freqs), torch.sin(freqs)
// cos = cos.unsqueeze(-1)
// sin = sin.unsqueeze(-1)
// t = t.reshape(*t.shape[:-1], -1, 2)
// t_out = (t * cos) + (_rotate_half(t) * sin)
// t_out = t_out.flatten(3)
//
// Which is equivalent to the PyTorch implementation using complex numbers:
//
// t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
// freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_) # freqs_ci[:,:,None,:]
// freqs_ci = freqs_ci.to(t_.device)
// t_out = torch.view_as_real(t_ * freqs_ci).flatten(3)
//
// Due to 1) the dimensional and 2) the datatype limitations of current backends,
// we need to use a different approach to achieve the same result.
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
// t1 = t[..., 0::2]
t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
// t2 = t[..., 1::2]
t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
return cosOut.Add(ctx, sinOut)
}
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2))
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2))
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2))
query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
key = applyVisionRotaryEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx)
hiddenStates = mlp.FC2.Forward(ctx, hiddenStates)
return hiddenStates
}
type VisionLayer struct {
InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"`
*VisionAttention
PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"`
*VisionMLP `gguf:"mlp"`
}
func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
residual := hiddenStates
// self attention
hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
// MLP
residual = hiddenStates
hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type VisionAdapter struct {
FC1 *nn.Linear `gguf:"mlp.fc1"`
FC2 *nn.Linear `gguf:"mlp.fc2"`
}
func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
patches := hiddenStates.Dim(1)
patchSize := int(math.Sqrt(float64(patches)))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2))
channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio)
hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio)
hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles)
hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx)
hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx)
return hiddenStates
}
type VisionOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
ropeTheta float32
eps float32
pixelShuffleRatio float32
}
type PatchEmbedding struct {
*nn.Linear
}
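// Forward implements the convolutional patch embedding as IM2Col plus a
// linear projection: unfolding patchSize x patchSize blocks at stride
// patchSize and multiplying by the weight is equivalent to a strided Conv2D.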
func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2))
hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3))
return p.Linear.Forward(ctx, hiddenStates)
}
type VisionModel struct {
Layers []VisionLayer `gguf:"blk"`
*PatchEmbedding `gguf:"patch_embedding"`
ClassEmbedding ml.Tensor `gguf:"class_embedding"`
PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"`
LayerNormPre *nn.LayerNorm `gguf:"layernorm_pre"`
LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"`
*VisionAdapter `gguf:"vision_adapter"`
*VisionOptions
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionLayer, c.Uint("vision.block_count")),
VisionOptions: &VisionOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
ropeTheta: float32(c.Float("vision.rope.freq_base")),
eps: c.Float("vision.layer_norm_epsilon"),
pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")),
},
}
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions)
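// append a learned class token after the patch tokens of each tile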
hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1)
hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding)
hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps)
cos, sin := m.rotaryEmbedding(ctx)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
}
hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps)
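// drop the class token appended above before handing off to the adapter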
hiddenStates = hiddenStates.Unpad(ctx, 0, 1, 0, 0)
hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions)
return hiddenStates
}
// floorDiv is a helper function to perform floor division. This mimics PyTorch's div(round_mode='floor') function
// which in turn mimics Python's // operator.
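// For example, floorDiv(-7, 2) == -4, matching Python's -7 // 2, whereas Go's
// native integer division truncates toward zero (-7 / 2 == -3).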
func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T {
if b == 0 {
panic("division by zero")
}
if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 {
return a / b
}
return a/b - 1
}
func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
patchesPerSide := m.imageSize / m.patchSize
numPatches := patchesPerSide*patchesPerSide + 1
headDim := m.hiddenSize / m.numHeads
freqDim := headDim / 2
freqs := make([]float32, numPatches*freqDim)
for i := range numPatches - 1 {
for j := 0; j < freqDim; j += 2 {
positionX := i*freqDim/2 + j/2
positionY := (i+numPatches)*freqDim/2 + j/2
ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim))
freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq)
freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq)
}
}
ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
if err != nil {
panic(err)
}
ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx)
}
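
Read as scalars, the reshape/permute choreography in applyVisionRotaryEmbedding is the usual complex rotation over interleaved (even, odd) pairs. A minimal plain-Go sketch of the same arithmetic, assuming one angle per pair (rotatePairs is an illustrative helper, not part of this change):

package main

import (
	"fmt"
	"math"
)

// rotatePairs computes out[2k] = x[2k]*cos[k] - x[2k+1]*sin[k] and
// out[2k+1] = x[2k]*sin[k] + x[2k+1]*cos[k], i.e. what cosOut.Add(ctx, sinOut)
// produces element-wise in the tensor version above.
func rotatePairs(x, cos, sin []float32) []float32 {
	out := make([]float32, len(x))
	for i := 0; i+1 < len(x); i += 2 {
		re, im := x[i], x[i+1]
		c, s := cos[i/2], sin[i/2]
		out[i] = re*c - im*s
		out[i+1] = re*s + im*c
	}
	return out
}

func main() {
	c := float32(math.Cos(math.Pi / 2))
	s := float32(math.Sin(math.Pi / 2))
	// rotating the pair (1, 0) by 90 degrees yields approximately (0, 1)
	fmt.Println(rotatePairs([]float32{1, 0}, []float32{c}, []float32{s}))
}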

View File

@@ -0,0 +1,167 @@
package llama4
import (
"cmp"
"image"
"math"
"slices"
"sort"
"golang.org/x/image/draw"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels, maxUpscalingSize int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels", 3)),
maxUpscalingSize: int(c.Uint("vision.max_upscaling_size", 448)),
}
}
func factors(n int) []int {
var result []int
seen := make(map[int]bool)
for i := 1; i <= n/2; i++ {
if n%i == 0 && !seen[i] {
result = append(result, i)
seen[i] = true
}
}
result = append(result, n)
sort.Ints(result)
return result
}
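// supportedResolutions enumerates every tile grid of up to patchSize tiles
// (x tiles wide by y tiles tall for each factorization), scaled to imageSize
// pixels per tile and grouped by aspect ratio.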
func (p ImageProcessor) supportedResolutions() []image.Point {
var resolutions []image.Point
aspectMap := make(map[float64][]image.Point)
for i := p.patchSize; i >= 1; i-- {
for _, f := range factors(i) {
x := f
y := i / f
k := float64(y) / float64(x)
aspectMap[k] = append(aspectMap[k], image.Point{x, y})
}
}
for _, v := range aspectMap {
for _, i := range v {
resolutions = append(resolutions, image.Point{i.X * p.imageSize, i.Y * p.imageSize})
}
}
return resolutions
}
func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []image.Point, resizeToMaxCanvas bool) image.Point {
w, h := img.X, img.Y
scales := make([]float64, len(possibleResolutions))
for i, res := range possibleResolutions {
scaleW := float64(res.X) / float64(w)
scaleH := float64(res.Y) / float64(h)
scale := math.Min(scaleW, scaleH)
scales[i] = scale
}
minAboveOne := func(scales []float64) (float64, bool) {
min := math.MaxFloat64
found := false
for _, s := range scales {
if s >= 1.0 && s < min {
min = s
found = true
}
}
return min, found
}
bestScale, ok := minAboveOne(scales)
if resizeToMaxCanvas || !ok {
bestScale = slices.Max(scales)
}
var bestOptions []image.Point
for i, scale := range scales {
if math.Abs(scale-bestScale) < 1e-6 {
bestOptions = append(bestOptions, possibleResolutions[i])
}
}
var chosenResolution image.Point
if len(bestOptions) > 1 {
chosenResolution = slices.MinFunc(bestOptions, func(a, b image.Point) int {
return cmp.Compare(a.X*a.Y, b.X*b.Y)
})
} else {
chosenResolution = bestOptions[0]
}
return chosenResolution
}
func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Point {
scaleW := float64(targetRes.X) / float64(imageRes.X)
scaleH := float64(targetRes.Y) / float64(imageRes.Y)
var newRes image.Point
if scaleW < scaleH {
newRes = image.Point{
targetRes.X,
int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
}
} else {
newRes = image.Point{
int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
targetRes.Y,
}
}
return newRes
}
func (p ImageProcessor) pad(src image.Image, outputSize image.Point) image.Image {
dst := image.NewRGBA(image.Rect(0, 0, outputSize.X, outputSize.Y))
draw.Draw(dst, src.Bounds(), src, image.Point{}, draw.Over)
return dst
}
func (p ImageProcessor) ProcessImage(img image.Image) (pixelsLocal, pixelsGlobal []float32, targetSize image.Point, _ error) {
img = imageproc.Composite(img)
targetSize = p.bestResolution(img.Bounds().Max, p.supportedResolutions(), false)
targetSizeWithoutDistortion := targetSize
if p.maxUpscalingSize > 0 {
targetSizeWithoutDistortion = p.maxResolution(img.Bounds().Max, targetSize)
targetSizeWithoutDistortion.X = min(max(img.Bounds().Max.X, p.maxUpscalingSize), targetSize.X)
targetSizeWithoutDistortion.Y = min(max(img.Bounds().Max.Y, p.maxUpscalingSize), targetSize.Y)
}
newSizeWithoutDistortion := p.maxResolution(img.Bounds().Max, targetSizeWithoutDistortion)
padded := p.pad(imageproc.Resize(img, newSizeWithoutDistortion, imageproc.ResizeBilinear), targetSize)
pixelsLocal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true)
if targetSize.X/p.imageSize*targetSize.Y/p.imageSize > 1 {
padded := imageproc.Resize(img, image.Point{p.imageSize, p.imageSize}, imageproc.ResizeBilinear)
pixelsGlobal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true)
}
return pixelsLocal, pixelsGlobal, targetSize, nil
}
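
To make the two return values concrete: pixelsLocal always covers the padded best-fit canvas, while pixelsGlobal is only produced when that canvas spans more than one imageSize tile. A hypothetical same-package helper ("describe" is illustrative and additionally needs "fmt" imported):

func describe(p ImageProcessor, img image.Image) {
	local, global, size, err := p.ProcessImage(img)
	if err != nil {
		panic(err)
	}
	tiles := (size.X / p.imageSize) * (size.Y / p.imageSize)
	fmt.Printf("canvas %v, %d tile(s): %d local floats, %d global floats\n",
		size, tiles, len(local), len(global))
}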

View File

@@ -0,0 +1,300 @@
package llama4
import (
"cmp"
"image"
"image/color"
"reflect"
"slices"
"testing"
gocmp "github.com/google/go-cmp/cmp"
)
func TestFactors(t *testing.T) {
tests := []struct {
name string
input int
expected []int
}{
{
name: "factors of 1",
input: 1,
expected: []int{1},
},
{
name: "factors of 2",
input: 2,
expected: []int{1, 2},
},
{
name: "factors of 6",
input: 6,
expected: []int{1, 2, 3, 6},
},
{
name: "factors of 28",
input: 28,
expected: []int{1, 2, 4, 7, 14, 28},
},
{
name: "factors of 49",
input: 49,
expected: []int{1, 7, 49},
},
{
name: "factors of 97 (prime)",
input: 97,
expected: []int{1, 97},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := factors(tt.input)
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("factors(%d) = %v; want %v", tt.input, actual, tt.expected)
}
})
}
}
func TestSupportedResolutions(t *testing.T) {
expectedResolutions := []image.Point{
{X: 3360, Y: 336},
{X: 672, Y: 2688},
{X: 336, Y: 1344},
{X: 336, Y: 4032},
{X: 1008, Y: 1344},
{X: 1344, Y: 1008},
{X: 336, Y: 1680},
{X: 1680, Y: 336},
{X: 336, Y: 5040},
{X: 4032, Y: 336},
{X: 2352, Y: 336},
{X: 2688, Y: 672},
{X: 1344, Y: 336},
{X: 5376, Y: 336},
{X: 2352, Y: 672},
{X: 672, Y: 1008},
{X: 1008, Y: 672},
{X: 336, Y: 5376},
{X: 1680, Y: 1008},
{X: 5040, Y: 336},
{X: 336, Y: 3024},
{X: 3024, Y: 336},
{X: 336, Y: 2688},
{X: 672, Y: 1344},
{X: 336, Y: 672},
{X: 336, Y: 2352},
{X: 2016, Y: 672},
{X: 1008, Y: 336},
{X: 336, Y: 3360},
{X: 336, Y: 4368},
{X: 1008, Y: 1680},
{X: 336, Y: 4704},
{X: 4704, Y: 336},
{X: 1344, Y: 672},
{X: 672, Y: 336},
{X: 2688, Y: 336},
{X: 3696, Y: 336},
{X: 2016, Y: 336},
{X: 1344, Y: 1344},
{X: 1008, Y: 1008},
{X: 672, Y: 672},
{X: 336, Y: 336},
{X: 4368, Y: 336},
{X: 672, Y: 2016},
{X: 336, Y: 1008},
{X: 336, Y: 3696},
{X: 672, Y: 1680},
{X: 1680, Y: 672},
{X: 336, Y: 2016},
{X: 672, Y: 2352},
}
sortResolutionFunc := func(a, b image.Point) int {
return cmp.Or(cmp.Compare(a.X, b.X), cmp.Compare(a.Y, b.Y))
}
slices.SortStableFunc(expectedResolutions, sortResolutionFunc)
imgProc := ImageProcessor{
imageSize: 336,
patchSize: 16,
numChannels: 3,
maxUpscalingSize: 448,
}
actualResolutions := imgProc.supportedResolutions()
slices.SortStableFunc(actualResolutions, sortResolutionFunc)
if diff := gocmp.Diff(expectedResolutions, actualResolutions); diff != "" {
t.Errorf("supportedResolutions() mismatch (-want +got):\n%s", diff)
}
}
func TestBestResolution(t *testing.T) {
tests := []struct {
name string
size image.Point
resolutions []image.Point
max bool
expected image.Point
}{
{
"normal",
image.Point{800, 600},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{800, 600},
},
{
"max",
image.Point{800, 600},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
true,
image.Point{1600, 1200},
},
{
"mid",
image.Point{1000, 700},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{1024, 768},
},
{
"smol",
image.Point{100, 100},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{300, 200},
},
{
"huge",
image.Point{10000, 10000},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{1600, 1200},
},
}
p := ImageProcessor{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := p.bestResolution(tt.size, tt.resolutions, tt.max)
if diff := gocmp.Diff(tt.expected, actual); diff != "" {
t.Errorf("best resolution mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestMaxResolution(t *testing.T) {
tests := []struct {
name string
origRes image.Point
targetRes image.Point
expected image.Point
}{
{
"normal",
image.Point{800, 600},
image.Point{800, 600},
image.Point{800, 600},
},
{
"skew",
image.Point{800, 600},
image.Point{1100, 700},
image.Point{933, 700},
},
}
p := ImageProcessor{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := p.maxResolution(tt.origRes, tt.targetRes)
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("max resolution; got %v want %v", actual, tt.expected)
}
})
}
}
func TestProcessImage(t *testing.T) {
imgProc := ImageProcessor{
imageSize: 336,
patchSize: 16,
numChannels: 3,
maxUpscalingSize: 448,
}
generateImage := func(seed int) image.Image {
width, height := 20, 10
img := image.NewRGBA(image.Rect(0, 0, width, height))
for x := range width {
// Use the seed to vary color generation
r := uint8((seed + x*11) % 256)
g := uint8((seed + x*17) % 256)
b := uint8((seed + x*23) % 256)
c := color.RGBA{R: r, G: g, B: b, A: 255}
for y := range height {
img.Set(x, y, c)
}
}
return img
}
pixelsLocal, pixelsGlobal, targetSize, err := imgProc.ProcessImage(generateImage(12))
if err != nil {
t.Error(err)
}
if n := len(pixelsLocal); n != 336*336*3 {
t.Errorf("unexpected size of f32s: %d", n)
}
if n := len(pixelsGlobal); n > 0 {
t.Errorf("unexpected size of f32s: %d", n)
}
if !targetSize.Eq(image.Point{336, 336}) {
t.Errorf("unexpected target size: %v", targetSize)
}
}

View File

@@ -152,7 +152,7 @@ func NewTextModel(c fs.Config) (*TextModel, error) {
 	c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 	&model.Vocabulary{
 		Values: c.Strings("tokenizer.ggml.tokens"),
-		Types:  c.Uints("tokenizer.ggml.token_type"),
+		Types:  c.Ints("tokenizer.ggml.token_type"),
 		Merges: c.Strings("tokenizer.ggml.merges"),
 		BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
 		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),

View File

@@ -43,7 +43,7 @@ func New(c fs.Config) (model.Model, error) {
 	c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 	&model.Vocabulary{
 		Values: c.Strings("tokenizer.ggml.tokens"),
-		Types:  c.Uints("tokenizer.ggml.token_type"),
+		Types:  c.Ints("tokenizer.ggml.token_type"),
 		Merges: c.Strings("tokenizer.ggml.merges"),
 		BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
 		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),

View File

@@ -177,7 +177,7 @@ type TextDecoder struct {
 func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	for i, layer := range d.Layers {
 		layerType := selfAttentionLayer
-		if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
+		if slices.Contains(opts.crossAttentionLayers, int32(i)) {
 			layerType = crossAttentionLayer
 		}
@@ -202,7 +202,7 @@ type TextModelOptions struct {
 	eps, ropeBase, ropeScale float32
 	ropeDim uint32
-	crossAttentionLayers []uint32
+	crossAttentionLayers []int32
 }
 type TextModel struct {
@@ -225,7 +225,7 @@ func newTextModel(c fs.Config) *TextModel {
 	var decoderLayers []TextDecoderLayer
 	for i := range c.Uint("block_count") {
 		var textDecoderLayer TextDecoderLayer
-		if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
+		if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) {
 			textDecoderLayer = &TextCrossAttentionDecoderLayer{}
 		} else {
 			textDecoderLayer = &TextSelfAttentionDecoderLayer{}
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
 		ropeBase:  c.Float("rope.freq_base"),
 		ropeScale: c.Float("rope.freq_scale", 1),
 		ropeDim:   c.Uint("rope.dimension_count"),
-		crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
+		crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
 	},
 }

View File

@@ -96,10 +96,10 @@ type VisionEncoder struct {
 	Layers []VisionEncoderLayer
 }
-func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
+func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []int32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
 	var intermediateHiddenStates []ml.Tensor
 	for i, layer := range e.Layers {
-		if slices.Contains(intermediateLayersIndices, uint32(i)) {
+		if slices.Contains(intermediateLayersIndices, int32(i)) {
 			intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
 		}
@@ -154,7 +154,7 @@ type VisionModelOptions struct {
 	imageSize, patchSize int
 	eps float32
-	intermediateLayersIndices []uint32
+	intermediateLayersIndices []int32
 }
 type VisionModel struct {
@@ -229,7 +229,7 @@ func newVisionModel(c fs.Config) *VisionModel {
 			eps: c.Float("vision.attention.layer_norm_epsilon"),
-			intermediateLayersIndices: c.Uints("vision.intermediate_layers_indices"),
+			intermediateLayersIndices: c.Ints("vision.intermediate_layers_indices"),
 		},
 	}
 }

View File

@@ -4,6 +4,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama" _ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama" _ "github.com/ollama/ollama/model/models/mllama"
) )

View File

@@ -37,7 +37,7 @@ type TextProcessor interface {
 type Vocabulary struct {
 	Values []string
-	Types  []uint32
+	Types  []int32
 	Scores []float32
 	Merges []string

View File

@@ -35,9 +35,9 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
 		sentencepiece.ModelProto_SentencePiece_CONTROL,
 		sentencepiece.ModelProto_SentencePiece_UNUSED,
 		sentencepiece.ModelProto_SentencePiece_BYTE:
-		v.Types = append(v.Types, uint32(t))
+		v.Types = append(v.Types, int32(t))
 	default:
-		tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
+		tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
 		// todo parse the special tokens file
 		// - this will roundtrip correctly but the <start_of_turn> and
 		// <end_of_turn> tokens aren't processed
@@ -124,7 +124,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
 			"<0xC3>",
 			"<0xA3>",
 		},
-		Types: []uint32{
+		Types: []int32{
 			TOKEN_TYPE_NORMAL,
 			TOKEN_TYPE_BYTE,
 			TOKEN_TYPE_BYTE,

View File

@@ -28,7 +28,7 @@ func llama(t testing.TB) BytePairEncoding {
 		t.Fatal(err)
 	}
-	types := make([]uint32, len(vocab))
+	types := make([]int32, len(vocab))
 	tokens := make([]string, len(vocab))
 	for token, id := range vocab {
 		tokens[id] = token

View File

@@ -64,7 +64,7 @@ func formatDuration(d time.Duration) string {
 func (b *Bar) String() string {
 	termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
 	if err != nil {
-		termWidth = defaultTermWidth
+		termWidth = 80
 	}
 	var pre strings.Builder

View File

@@ -4,16 +4,8 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"io" "io"
"os"
"sync" "sync"
"time" "time"
"golang.org/x/term"
)
const (
defaultTermWidth = 80
defaultTermHeight = 24
) )
type State interface { type State interface {
@@ -91,11 +83,6 @@ func (p *Progress) Add(key string, state State) {
} }
func (p *Progress) render() { func (p *Progress) render() {
_, termHeight, err := term.GetSize(int(os.Stderr.Fd()))
if err != nil {
termHeight = defaultTermHeight
}
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
@@ -115,9 +102,8 @@ func (p *Progress) render() {
fmt.Fprint(p.w, "\033[1G") fmt.Fprint(p.w, "\033[1G")
// render progress lines // render progress lines
maxHeight := min(len(p.states), termHeight) for i, state := range p.states {
for i := len(p.states) - maxHeight; i < len(p.states); i++ { fmt.Fprint(p.w, state.String(), "\033[K")
fmt.Fprint(p.w, p.states[i].String(), "\033[K")
if i < len(p.states)-1 { if i < len(p.states)-1 {
fmt.Fprint(p.w, "\n") fmt.Fprint(p.w, "\n")
} }

View File

@@ -723,7 +723,9 @@ func (m *multiLPath) String() string {
 	return strings.Join(*m, ", ")
 }
-func (s *Server) reserveWorstCaseGraph() error {
+// TODO(jessegross): This is causing tensor allocation failures with large batches when not offloaded
+// to the GPU
+/*func (s *Server) reserveWorstCaseGraph() error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
@@ -766,7 +768,7 @@
 	}
 	return nil
-}
+}*/
 func (s *Server) loadModel(
 	ctx context.Context,
@@ -803,10 +805,10 @@
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
-	err = s.reserveWorstCaseGraph()
+	/*err = s.reserveWorstCaseGraph()
 	if err != nil {
 		panic(err)
-	}
+	}*/
 	s.status = llm.ServerStatusReady
 	s.ready.Done()

View File

@@ -74,7 +74,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
 		t.Fatal(err)
 	}
-	types := make([]uint32, len(vocab))
 	tokens := make([]string, len(vocab))
 	for token, id := range vocab {
 		tokens[id] = token
@@ -86,7 +85,7 @@
 		``,
 		&model.Vocabulary{
 			Values: tokens,
-			Types:  types,
+			Types:  make([]int32, len(vocab)),
 			Merges: merges,
 		},
 	)

View File

@@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
 	}
 	defer bin.Close()
-	f, _, err := ggml.Decode(bin, 0)
+	f, _, err := ggml.Decode(bin, 1024)
 	if err != nil {
 		return nil, err
 	}
@@ -457,7 +457,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
 		return nil, err
 	}
-	f, _, err := ggml.Decode(temp, 0)
+	f, _, err := ggml.Decode(temp, 1024)
 	if err != nil {
 		slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
 		return nil, err
@@ -499,7 +499,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
 	var offset int64
 	for offset < stat.Size() {
-		f, n, err := ggml.Decode(blob, 0)
+		f, n, err := ggml.Decode(blob, 1024)
 		if errors.Is(err, io.EOF) {
 			break
 		} else if err != nil {

View File

@@ -75,7 +75,7 @@ func (m *Model) Capabilities() []model.Capability {
 	if err == nil {
 		defer r.Close()
-		f, _, err := ggml.Decode(r, 0)
+		f, _, err := ggml.Decode(r, 1024)
 		if err == nil {
 			if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
 				capabilities = append(capabilities, model.CapabilityEmbedding)

View File

@@ -73,8 +73,13 @@ type statusCodeRecorder struct {
 func (r *statusCodeRecorder) WriteHeader(status int) {
 	if r._status == 0 {
 		r._status = status
+		r.ResponseWriter.WriteHeader(status)
 	}
-	r.ResponseWriter.WriteHeader(status)
 }
+func (r *statusCodeRecorder) Write(b []byte) (int, error) {
+	r._status = r.status()
+	return r.ResponseWriter.Write(b)
+}
 var (
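
The recorder pattern above is easy to check in isolation. A standalone sketch, assuming (as the unshown status() helper suggests) that a Write without a prior WriteHeader implies an implicit 200:

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

type recorder struct {
	http.ResponseWriter
	status int
}

func (r *recorder) WriteHeader(code int) {
	if r.status == 0 {
		r.status = code
		r.ResponseWriter.WriteHeader(code) // forward only the first call
	}
}

func (r *recorder) Write(b []byte) (int, error) {
	if r.status == 0 {
		r.status = http.StatusOK // a bare Write implies status 200
	}
	return r.ResponseWriter.Write(b)
}

func main() {
	rec := &recorder{ResponseWriter: httptest.NewRecorder()}
	fmt.Fprint(rec, "ok")
	fmt.Println(rec.status) // 200
}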

View File

@@ -64,7 +64,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	}
 	defer blob.Close()
-	f, _, err := ggml.Decode(blob, 0)
+	f, _, err := ggml.Decode(blob, 1024)
 	if err != nil {
 		return nil, err
 	}

View File

@@ -18,6 +18,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"slices" "slices"
"strings" "strings"
"syscall" "syscall"
@@ -1512,6 +1513,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Messages[0].Role != "system" && m.System != "" { if req.Messages[0].Role != "system" && m.System != "" {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...) msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
} }
msgs = filterThinkTags(msgs, m)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil { if err != nil {
@@ -1640,3 +1642,23 @@ func handleScheduleError(c *gin.Context, name string, err error) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} }
} }
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
finalUserIndex := -1
for i, msg := range msgs {
if msg.Role == "user" {
finalUserIndex = i
}
}
for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex {
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
}
}
}
return msgs
}
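
The stripping behavior is driven entirely by thinkTagRegexp; a minimal standalone reproduction (the input string is made up):

package main

import (
	"fmt"
	"regexp"
)

var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)

func main() {
	in := "<think>chain\nof\nthought</think>\n\nfinal answer"
	// (?s) lets .*? cross newlines, and the trailing (\n)* swallows the
	// blank lines left behind after the closing tag
	fmt.Printf("%q\n", thinkTagRegexp.ReplaceAllString(in, "")) // "final answer"
}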

View File

@@ -299,9 +299,6 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "Hello!"}, {Role: "user", Content: "Hello!"},
}, },
Stream: &stream, Stream: &stream,
Options: map[string]any{
"num_ctx": 1024,
},
}) })
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
@@ -324,9 +321,6 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "Hello!"}, {Role: "user", Content: "Hello!"},
}, },
Stream: &stream, Stream: &stream,
Options: map[string]any{
"num_ctx": 1024,
},
}) })
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
@@ -350,9 +344,6 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "Help me write tests."}, {Role: "user", Content: "Help me write tests."},
}, },
Stream: &stream, Stream: &stream,
Options: map[string]any{
"num_ctx": 1024,
},
}) })
if w.Code != http.StatusOK { if w.Code != http.StatusOK {

View File

@@ -15,6 +15,7 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"sort" "sort"
"strings" "strings"
"testing" "testing"
@@ -746,3 +747,128 @@ func TestNormalize(t *testing.T) {
}) })
} }
} }
func TestFilterThinkTags(t *testing.T) {
type testCase struct {
msgs []api.Message
want []api.Message
model *Model
}
testCases := []testCase{
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "qwen3",
},
},
},
// with newlines inside the think tag and newlines after
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... \n\nabout \nthe answer</think>\n\nabc\ndef"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc\ndef"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "qwen3",
},
},
},
// should leave thinking tags if it's after the last user message
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking...</think>after"},
{Role: "user", Content: "What is the answer?"},
{Role: "assistant", Content: "<think>thinking again</think>hjk"},
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "after"},
{Role: "user", Content: "What is the answer?"},
{Role: "assistant", Content: "<think>thinking again</think>hjk"},
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "qwen3",
},
},
},
{
// shouldn't strip anything because the model family isn't one of the hardcoded ones
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "llama3",
},
},
},
{
// deepseek-r1:-prefixed model
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Name: "registry.ollama.ai/library/deepseek-r1:latest",
ShortName: "deepseek-r1:7b",
Config: ConfigV2{},
},
},
}
for i, tc := range testCases {
filtered := filterThinkTags(tc.msgs, tc.model)
if !reflect.DeepEqual(filtered, tc.want) {
t.Errorf("messages differ for case %d:", i)
for i := range tc.want {
if i >= len(filtered) {
t.Errorf(" missing message %d: %+v", i, tc.want[i])
continue
}
if !reflect.DeepEqual(filtered[i], tc.want[i]) {
t.Errorf(" message %d:\n want: %+v\n got: %+v", i, tc.want[i], filtered[i])
}
}
if len(filtered) > len(tc.want) {
for i := len(tc.want); i < len(filtered); i++ {
t.Errorf(" extra message %d: %+v", i, filtered[i])
}
}
}
}
}

View File

@@ -81,6 +81,10 @@ func InitScheduler(ctx context.Context) *Scheduler {
 // context must be canceled to decrement ref count and release the runner
 func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
+	if opts.NumCtx < 4 {
+		opts.NumCtx = 4
+	}
 	req := &LlmRequest{
 		ctx:   c,
 		model: model,
@@ -110,11 +114,6 @@ func (s *Scheduler) Run(ctx context.Context) {
 	}()
 }
-const (
-	defaultContextLength  = 4096
-	smallGpuContextLength = 2048
-)
 func (s *Scheduler) processPending(ctx context.Context) {
 	for {
 		select {
@@ -167,17 +166,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			gpus = s.getGpuFn()
 		}
-		if pending.origNumCtx == -1 {
-			if len(gpus) == 1 && gpus[0].Library != "cpu" && gpus[0].TotalMemory <= 4096*1024*1024 {
-				slog.Info("GPU is small, limiting default context window", "num_ctx", smallGpuContextLength)
-				pending.opts.NumCtx = smallGpuContextLength
-				pending.origNumCtx = smallGpuContextLength
-			} else {
-				pending.opts.NumCtx = defaultContextLength
-				pending.origNumCtx = defaultContextLength
-			}
-		}
 		if envconfig.MaxRunners() <= 0 {
 			// No user specified MaxRunners, so figure out what automatic setting to use
 			// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
@@ -453,10 +441,9 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
 		estimatedVRAM:  llama.EstimatedVRAM(),
 		estimatedTotal: llama.EstimatedTotal(),
 		loading:        true,
-		refCount:       1,
 	}
 	runner.numParallel = numParallel
-	runner.refMu.Lock()
+	runner.refMu.Lock() // hold lock until running or aborted
 	s.loadedMu.Lock()
 	s.loaded[req.model.ModelPath] = runner
@@ -467,13 +454,13 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
 		defer runner.refMu.Unlock()
 		if err = llama.WaitUntilRunning(req.ctx); err != nil {
 			slog.Error("error loading llama server", "error", err)
-			runner.refCount--
 			req.errCh <- err
 			slog.Debug("triggering expiration for failed load", "model", runner.modelPath)
 			s.expiredCh <- runner
 			return
 		}
 		slog.Debug("finished setting up runner", "model", req.model.ModelPath)
+		runner.refCount++
 		runner.loading = false
 		go func() {
 			<-req.ctx.Done()
@@ -491,7 +478,12 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
 	}
 	predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
 	s.loadedMu.Lock()
+	runners := make([]*runnerRef, 0, len(s.loaded))
 	for _, r := range s.loaded {
+		runners = append(runners, r)
+	}
+	s.loadedMu.Unlock()
+	for _, r := range runners {
 		r.refMu.Lock()
 		if r.llama != nil {
 			for _, gpu := range allGpus {
@@ -502,7 +494,6 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
 		}
 		r.refMu.Unlock()
 	}
-	s.loadedMu.Unlock()
 	// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
 	for i := range allGpus {
@@ -549,10 +540,8 @@ func (s *Scheduler) filterGPUsWithoutLoadingModels(allGpus discover.GpuInfoList)
 // TODO consolidate sched_types.go
 type runnerRef struct {
 	refMu sync.Mutex
-	// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
 	refCount uint // prevent unloading if > 0
-	// unloading bool // set to true when we are trying to unload the runner
 	llama   llm.LlamaServer
 	loading bool // True only during initial load, then false forever
@@ -823,8 +812,8 @@ func (s *Scheduler) unloadAllRunners() {
 func (s *Scheduler) expireRunner(model *Model) {
 	s.loadedMu.Lock()
-	defer s.loadedMu.Unlock()
 	runner, ok := s.loaded[model.ModelPath]
+	s.loadedMu.Unlock()
 	if ok {
 		runner.refMu.Lock()
 		runner.expiresAt = time.Now()
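
The updateFreeSpace hunk above is the standard cure for a lock-order inversion: snapshot the map under its own lock, release it, and only then take each runner's lock. A minimal sketch of that pattern with simplified, illustrative types:

package main

import "sync"

type item struct {
	mu   sync.Mutex
	used uint64
}

// snapshotThenLock never holds mapMu while acquiring an item's mu, so it
// cannot deadlock against a goroutine that locks an item first and the map
// second (the inversion fixed above).
func snapshotThenLock(mapMu *sync.Mutex, m map[string]*item) (total uint64) {
	mapMu.Lock()
	items := make([]*item, 0, len(m))
	for _, it := range m {
		items = append(items, it)
	}
	mapMu.Unlock()

	for _, it := range items {
		it.mu.Lock()
		total += it.used
		it.mu.Unlock()
	}
	return total
}

func main() {
	var mapMu sync.Mutex
	m := map[string]*item{"a": {used: 1}, "b": {used: 2}}
	_ = snapshotThenLock(&mapMu, m)
}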

View File

@@ -148,7 +148,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
 		successCh: make(chan *runnerRef, 1),
 		errCh:     make(chan error, 1),
 	}
-	b.req.opts.NumCtx = 4096
 	b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
 	return b
 }