Compare commits

..

33 Commits

Author SHA1 Message Date
Devon Rifkin
ad3c7c9bda strip out thinking tags in message history for qwen3 & r1 (#10490)
* strip out thinking tags in message history for qwen3 & r1

This is in advance of "proper" support where we'll make reasoning
configurable and we'll parse out thinking/reasoning tags and provide
them to the caller. These models expect there to be no thinking tags in
the message history, so this should improve quality

* parse model names instead of hacky prefix check
2025-04-30 13:57:45 -07:00
Daniel Hiltgen
415c8fcc3d Fix "Stopping..." scheduler hang (#10487)
* Adjust initial scheduler refCount

Ensure we only set the refCount on success

* sched: fix lock order inversion deadlock

Under certain race conditions, there was a scenario where the scheduler would
get into a deadlock while trying to update free space information while a model
was trying to unload.
2025-04-30 11:26:52 -07:00
Daniel Hiltgen
718eda1b3e Narrow set of paths we load GGML from (#10485)
Users may have other incompatible GGML installs on their systems.
This will prevent us from trying to load them from the path.
2025-04-30 11:25:22 -07:00
Shahin R
421b7edeb4 readme: add link to lumina, a lightweight React frontend client (#10378) 2025-04-30 09:50:47 -07:00
batuhankadioglu
7b68e254c2 all: update several golang.org/x packages (#10436) 2025-04-29 16:51:09 -07:00
Daniel Hiltgen
7bec2724a5 integration: fix embedding tests error handling (#10478)
The cleanup routine from InitServerconnection should run in the defer of the test case to properly detect failures and report the server logs
2025-04-29 11:57:54 -07:00
Jesse Gross
a27462b708 ollamarunner: Temporarily disable worst case graph preallocation
When we later have a large batch running purely on a CPU, this
results the error:
GGML_ASSERT(talloc->buffer_id >= 0)

Disabling this means that we will incrementally reallocate memory
as the graph grows.

Fixes #10410
2025-04-29 11:04:58 -07:00
crStiv
6bf0b8193a readme: fix typos (#10399) 2025-04-29 10:30:44 -07:00
Devon Rifkin
db428adbb8 Merge pull request #10468 from ollama/drifkin/num-parallel-1 2025-04-29 10:21:36 -07:00
Devon Rifkin
fe5b9bb21b lower default num parallel to 2
this is in part to "pay" for #10452, which doubled the default context length. The combination isn't fully neutral though, because even though the old 4x2k limit and the new 2x4k limit are memory equivalent, the 1x fallback is larger with 4k
2025-04-29 02:04:14 -07:00
Devon Rifkin
6ec71d8fb6 Merge pull request #10452 from ollama/drifkin/4096-context-length
config: update default context length to 4096
2025-04-28 17:13:51 -07:00
Devon Rifkin
44b466eeb2 config: update default context length to 4096 2025-04-28 17:03:27 -07:00
Devon Rifkin
a25f3f8260 Merge pull request #10451 from ollama/revert-10364-drifkin/context-length
Revert "increase default context length to 4096"
2025-04-28 17:02:10 -07:00
Devon Rifkin
dd93e1af85 Revert "increase default context length to 4096 (#10364)"
This reverts commit 424f648632.
2025-04-28 16:54:11 -07:00
Michael Yang
5cfc1c39f3 model: fix build (#10416) 2025-04-25 19:24:48 -07:00
Michael Yang
f0ad49ea17 memory 2025-04-25 16:59:20 -07:00
Michael Yang
7ba9fa9c7d fixes for maverick 2025-04-25 16:59:20 -07:00
Michael Yang
8bf11b84c1 chunked attention 2025-04-25 16:59:20 -07:00
Michael Yang
470af8ab89 connect vision to text 2025-04-25 16:59:20 -07:00
Michael Yang
178761aef3 image processing
Co-authored-by: Patrick Devine <patrick@infrahq.com>
2025-04-25 16:59:20 -07:00
Michael Yang
f0c66e6dea llama4 2025-04-25 16:59:20 -07:00
Michael Yang
54055a6dae fix test 2025-04-25 16:59:01 -07:00
Michael Yang
340448d2d1 explicitly decode maxarraysize 1024 2025-04-25 16:59:01 -07:00
Michael Yang
ced7d0e53d fix parameter count 2025-04-25 16:59:01 -07:00
Michael Yang
a0dba0f8ae default slice values 2025-04-25 16:59:01 -07:00
Michael Yang
5e20b170a7 update comment 2025-04-25 16:59:01 -07:00
Michael Yang
d26c18e25c fix token type 2025-04-25 16:59:01 -07:00
Michael Yang
8d376acc9b zero means zero
use a default of 1024 when asking for zero is confusing since most calls
seem to assume 0 means do not ready any data
2025-04-25 16:59:01 -07:00
Michael Yang
dc1e81f027 convert: use -1 for read all 2025-04-25 16:59:01 -07:00
Michael Yang
5d0279164c generic ggml.array 2025-04-25 16:59:01 -07:00
Michael Yang
214a7678ea fix superfluous call to WriteHeader
the first call to http.ResponseWriter.Write implicitly calls WriteHeader
with http.StatusOK if it hasn't already been called. once WriteHeader
has been called, subsequent calls has no effect. Write is called when
JSON encoding progressUpdateJSON{}. calls to
http.ResponseWriter.WriteHeader after the first encode is useless and
produces a warning:

http: superfluous response.WriteHeader call from github.com/ollama/ollama/server/internal/registry.(*statusCodeRecorder).WriteHeader (server.go:77)
2025-04-25 16:58:49 -07:00
Michael Yang
4892872c18 convert: change to colmajor 2025-04-25 15:27:39 -07:00
Michael Yang
0b9198bf47 ci: silence deprecated gpu targets warning 2025-04-25 13:37:54 -07:00
58 changed files with 2026 additions and 406 deletions

View File

@@ -21,14 +21,16 @@
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86"
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
}
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120"
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
}
},
{

View File

@@ -285,7 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
@@ -325,14 +325,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support, and multiple large language models.)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord )
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in Discord)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy-to-use GUI with sample custom LLM for Drivers Education)
- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
@@ -341,16 +341,16 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
- [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
- [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita)
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
- [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
- [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
@@ -368,7 +368,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard, and said in the meetings)
- [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
- [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
@@ -386,7 +386,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
@@ -399,6 +399,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
### Cloud
@@ -440,7 +441,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
### Apple Vision Pro
@@ -515,7 +516,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
@@ -524,11 +525,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Mobile
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS, and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
@@ -552,7 +553,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
@@ -562,8 +563,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
- [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depend on ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)

View File

@@ -22,7 +22,6 @@ import (
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@@ -32,7 +31,6 @@ import (
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"golang.org/x/term"
"github.com/ollama/ollama/api"
@@ -108,7 +106,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
spinner.Stop()
req.Model = args[0]
req.Name = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
@@ -119,43 +117,26 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
var mu sync.Mutex
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
// copy files since we'll be modifying the map
temp := req.Files
req.Files = make(map[string]string, len(temp))
for f, digest := range temp {
g.Go(func() error {
if len(req.Files) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Files {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
mu.Lock()
req.Files[filepath.Base(f)] = digest
mu.Unlock()
return nil
})
fileMap[filepath.Base(f)] = digest
}
req.Files = fileMap
}
// copy files since we'll be modifying the map
temp = req.Adapters
req.Adapters = make(map[string]string, len(temp))
for f, digest := range temp {
g.Go(func() error {
if len(req.Adapters) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Adapters {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
mu.Lock()
req.Adapters[filepath.Base(f)] = digest
mu.Unlock()
return nil
})
}
if err := g.Wait(); err != nil {
return err
fileMap[filepath.Base(f)] = digest
}
req.Adapters = fileMap
}
bars := make(map[string]*progress.Bar)
@@ -232,7 +213,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
}
}()
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
@@ -1426,7 +1407,6 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_GPU_OVERHEAD"],
envVars["OLLAMA_LOAD_TIMEOUT"],
envVars["OLLAMA_CONTEXT_LENGTH"],
})
default:
appendEnvDocs(cmd, envs)

View File

@@ -690,7 +690,7 @@ func TestCreateHandler(t *testing.T) {
return
}
if req.Model != "test-model" {
if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
}

View File

@@ -4,9 +4,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"os"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
@@ -84,14 +85,6 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
func (ModelParameters) writeFile(f *os.File, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(f, kv, ts)
}
func (AdapterParameters) writeFile(f *os.File, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(f, kv, ts)
}
type ModelConverter interface {
// KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV
@@ -103,8 +96,6 @@ type ModelConverter interface {
// specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(*os.File, ggml.KV, []ggml.Tensor) error
}
type moreParser interface {
@@ -119,11 +110,9 @@ type AdapterConverter interface {
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
writeFile(*os.File, ggml.KV, []ggml.Tensor) error
}
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -158,14 +147,14 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return err
}
return conv.writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
return writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error {
func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return err
@@ -184,6 +173,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
switch p.Architectures[0] {
case "LlamaForCausalLM":
conv = &llamaModel{}
case "Llama4ForConditionalGeneration":
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "MixtralForCausalLM":
@@ -248,5 +239,13 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return err
}
return conv.writeFile(f, conv.KV(t), conv.Tensors(ts))
return writeFile(ws, conv.KV(t), conv.Tensors(ts))
}
func writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)
}
return ggml.WriteGGUF(ws, kv, ts)
}

View File

@@ -42,6 +42,8 @@ type llamaModel struct {
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
HeadDim uint32 `json:"head_dim"`
skipRepack bool
}
var _ ModelConverter = (*llamaModel)(nil)
@@ -70,6 +72,10 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
}
if p.HeadDim > 0 {
kv["llama.attention.head_dim"] = p.HeadDim
}
if p.RopeTheta > 0 {
kv["llama.rope.freq_base"] = p.RopeTheta
}
@@ -133,9 +139,10 @@ func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
}
for _, t := range ts {
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(p.repack)
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
if !p.skipRepack {
t.SetRepacker(p.repack)
}
}
out = append(out, ggml.Tensor{

169
convert/convert_llama4.go Normal file
View File

@@ -0,0 +1,169 @@
package convert
import (
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type llama4Model struct {
ModelParameters
TextModel struct {
llamaModel
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NumLocalExperts uint32 `json:"num_local_experts"`
InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
UseQKNorm bool `json:"use_qk_norm"`
IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"`
AttentionChunkSize uint32 `json:"attention_chunk_size"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
NormEpsilon float32 `json:"norm_eps"`
PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
} `json:"vision_config"`
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"
for k, v := range p.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
}
}
kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize
kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
return kv
}
// Replacements implements ModelConverter.
func (p *llama4Model) Replacements() []string {
return append(
p.TextModel.Replacements(),
"language_model.", "",
"vision_model", "v",
"multi_modal_projector", "mm",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.", "ffn_",
"shared_expert.down_proj", "down_shexp",
"shared_expert.gate_proj", "gate_shexp",
"shared_expert.up_proj", "up_shexp",
"experts.down_proj", "down_exps.weight",
"experts.gate_up_proj", "gate_up_exps.weight",
"router", "gate_inp",
"patch_embedding.linear", "patch_embedding",
)
}
// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
var textTensors []Tensor
for _, t := range ts {
if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
// gate and up projectors are fused
// dims[1], dims[2] must be swapped
// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
halfDim := int(t.Shape()[2]) / 2
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2]/2, newShape[1]
for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
// clone tensor since we need separate repackers
tt := t.Clone()
tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
out = append(out, ggml.Tensor{
Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
Kind: tt.Kind(),
Shape: newShape,
WriterTo: tt,
})
}
} else if strings.Contains(t.Name(), "ffn_down_exps") {
// dims[1], dims[2] must be swapped
// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
t.SetRepacker(p.repack())
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2], newShape[1]
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
} else {
textTensors = append(textTensors, t)
}
}
p.TextModel.skipRepack = true
out = append(out, p.TextModel.Tensors(textTensors)...)
return out
}
func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
return func(name string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
if err := t.T(0, 2, 1); err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be return as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@@ -11,7 +11,6 @@ import (
"io"
"io/fs"
"log/slog"
"math"
"os"
"path/filepath"
"slices"
@@ -48,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
}
t.Cleanup(func() { r.Close() })
m, _, err := ggml.Decode(r, math.MaxInt)
m, _, err := ggml.Decode(r, -1)
if err != nil {
t.Fatal(err)
}
@@ -332,7 +331,7 @@ func TestConvertAdapter(t *testing.T) {
}
defer r.Close()
m, _, err := ggml.Decode(r, math.MaxInt)
m, _, err := ggml.Decode(r, -1)
if err != nil {
t.Fatal(err)
}

View File

@@ -11,14 +11,15 @@ type Tensor interface {
Name() string
Shape() []uint64
Kind() uint32
SetRepacker(repacker)
SetRepacker(Repacker)
WriteTo(io.Writer) (int64, error)
Clone() Tensor
}
type tensorBase struct {
name string
shape []uint64
repacker
name string
shape []uint64
repacker Repacker
}
func (t tensorBase) Name() string {
@@ -36,7 +37,8 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
t.name == "token_types.weight" {
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" {
// these tensors are always F32
return 0
}
@@ -51,11 +53,11 @@ func (t tensorBase) Kind() uint32 {
}
}
func (t *tensorBase) SetRepacker(fn repacker) {
func (t *tensorBase) SetRepacker(fn Repacker) {
t.repacker = fn
}
type repacker func(string, []float32, []uint64) ([]float32, error)
type Repacker func(string, []float32, []uint64) ([]float32, error)
func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
patterns := []struct {

View File

@@ -94,6 +94,21 @@ type safetensor struct {
*tensorBase
}
func (st safetensor) Clone() Tensor {
return &safetensor{
fs: st.fs,
path: st.path,
dtype: st.dtype,
offset: st.offset,
size: st.size,
tensorBase: &tensorBase{
name: st.name,
repacker: st.repacker,
shape: slices.Clone(st.shape),
},
}
}
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
f, err := st.fs.Open(st.path)
if err != nil {

View File

@@ -43,6 +43,17 @@ type torch struct {
*tensorBase
}
func (t torch) Clone() Tensor {
return torch{
storage: t.storage,
tensorBase: &tensorBase{
name: t.name,
shape: t.shape,
repacker: t.repacker,
},
}
}
func (pt torch) WriteTo(w io.Writer) (int64, error) {
return 0, nil
}

View File

@@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?
By default, Ollama uses a context window size of 4096 tokens, unless you have a single GPU with <= 4 GB of VRAM, in which case it will default to 2048 tokens.
By default, Ollama uses a context window size of 4096 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
@@ -31,7 +31,7 @@ OLLAMA_CONTEXT_LENGTH=8192 ollama serve
To change this when using `ollama run`, use `/set parameter`:
```shell
/set parameter num_ctx 8192
/set parameter num_ctx 4096
```
When using the API, specify the `num_ctx` parameter:
@@ -41,7 +41,7 @@ curl http://localhost:11434/api/generate -d '{
"model": "llama3.2",
"prompt": "Why is the sky blue?",
"options": {
"num_ctx": 8192
"num_ctx": 4096
}
}'
```

View File

@@ -169,7 +169,7 @@ var (
// Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
ContextLength = Int64("OLLAMA_CONTEXT_LENGTH", -1)
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
)
func String(s string) func() string {
@@ -227,20 +227,6 @@ func Uint64(key string, defaultValue uint64) func() uint64 {
}
}
func Int64(key string, defaultValue int64) func() int64 {
return func() int64 {
if s := Var(key); s != "" {
if n, err := strconv.ParseInt(s, 10, 64); err != nil {
slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
} else {
return n
}
}
return defaultValue
}
}
// Set aside VRAM per GPU
var GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0)
@@ -269,7 +255,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default 4096 or 2048 with low VRAM)"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
// Informational

View File

@@ -278,9 +278,9 @@ func TestVar(t *testing.T) {
}
func TestContextLength(t *testing.T) {
cases := map[string]int64{
"": -1,
"4096": 4096,
cases := map[string]uint{
"": 4096,
"2048": 2048,
}
for k, v := range cases {

View File

@@ -8,6 +8,6 @@ type Config interface {
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32
}

View File

@@ -33,7 +33,7 @@ func (kv KV) Kind() string {
}
func (kv KV) ParameterCount() uint64 {
return keyValue[uint64](kv, "general.parameter_count")
return keyValue(kv, "general.parameter_count", uint64(0))
}
func (kv KV) FileType() fileType {
@@ -105,42 +105,42 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
r := keyValue(kv, key, &array{})
s := make([]string, r.size)
for i := range r.size {
s[i] = r.values[i].(string)
}
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
}
return s
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
}
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
r := keyValue(kv, key, &array{})
s := make([]uint32, r.size)
for i := range r.size {
s[i] = uint32(r.values[i].(int32))
}
return s
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
r := keyValue(kv, key, &array{})
s := make([]float32, r.size)
for i := range r.size {
s[i] = float32(r.values[i].(float32))
}
return s
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
}
func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{
"gemma3",
"mistral3",
"llama4",
}, kv.Architecture())
}
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
*array[string] | *array[float32] | *array[float64] | *array[bool]
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}
@@ -375,13 +375,8 @@ func DetectContentType(b []byte) string {
// Decode decodes a GGML model from the given reader.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
// maxArraySize. If the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
if maxArraySize == 0 {
maxArraySize = 1024
}
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32
@@ -420,7 +415,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCount()
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
@@ -435,7 +430,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
}
switch f.KV().Architecture() {
case "llama":
case "llama", "llama4":
fullOffload = max(
4*batch*(1+4*embedding+context*(1+heads)),
4*batch*(embedding+vocab),
@@ -449,7 +444,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
// mixtral 8x22b
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
ff := uint64(f.KV().Uint("feed_forward_length"))
partialOffload = max(
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@@ -466,9 +461,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
case "mllama":
var visionTokens, tiles uint64 = 1601, 4
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers")
for i := range kv {
if slices.Contains(crossAttentionLayers, uint32(i)) {
if slices.Contains(crossAttentionLayers, int32(i)) {
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
4 * // sizeof(float32)
visionTokens *
@@ -645,6 +640,9 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph
}
return weights, graphSize

View File

@@ -2,6 +2,7 @@ package ggml
import (
"maps"
"math"
"slices"
"strconv"
"strings"
@@ -210,3 +211,61 @@ func TestTensorTypes(t *testing.T) {
})
}
}
func TestKeyValue(t *testing.T) {
kv := KV{
"general.architecture": "test",
"test.strings": &array[string]{size: 3, values: []string{"a", "b", "c"}},
"test.float32s": &array[float32]{size: 3, values: []float32{1.0, 2.0, 3.0}},
"test.int32s": &array[int32]{size: 3, values: []int32{1, 2, 3}},
"test.uint32s": &array[uint32]{size: 3, values: []uint32{1, 2, 3}},
}
if diff := cmp.Diff(kv.Strings("strings"), []string{"a", "b", "c"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("nonexistent.strings"), []string(nil)); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("default.strings", []string{"ollama"}), []string{"ollama"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("float32s"), []float32{1.0, 2.0, 3.0}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("nonexistent.float32s"), []float32(nil)); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("default.float32s", []float32{math.MaxFloat32}), []float32{math.MaxFloat32}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("int32s"), []int32{1, 2, 3}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("nonexistent.int32s"), []int32(nil)); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("default.int32s", []int32{math.MaxInt32}), []int32{math.MaxInt32}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("uint32s"), []uint32{1, 2, 3}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("nonexistent.uint32s"), []uint32(nil)); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("default.uint32s", []uint32{math.MaxUint32}), []uint32{math.MaxUint32}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
}

View File

@@ -9,12 +9,8 @@ import (
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strings"
"golang.org/x/sync/errgroup"
)
type containerGGUF struct {
@@ -40,10 +36,6 @@ type containerGGUF struct {
maxArraySize int
}
func (c *containerGGUF) canCollectArray(size int) bool {
return c.maxArraySize < 0 || size <= c.maxArraySize
}
func (c *containerGGUF) Name() string {
return "gguf"
}
@@ -299,6 +291,23 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
return b.String(), nil
}
func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
for i := range a.size {
if a.values != nil {
e, err := readGGUFV1String(llm, r)
if err != nil {
return nil, err
}
a.values[i] = e
} else {
discardGGUFString(llm, r)
}
}
return a, nil
}
func discardGGUFString(llm *gguf, r io.Reader) error {
buf := llm.scratch[:8]
_, err := io.ReadFull(r, buf)
@@ -356,78 +365,44 @@ func writeGGUFString(w io.Writer, s string) error {
return err
}
type array struct {
size int
values []any
}
func (a *array) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
n, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, 0, int(n))
}
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
e, err = readGGUFV1String(llm, r)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
if err != nil {
return nil, err
}
func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
for i := range a.size {
if a.values != nil {
e, err := readGGUFString(llm, r)
if err != nil {
return nil, err
}
a.values[i] = e
} else {
discardGGUFString(llm, r)
}
}
return a, nil
}
func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
if llm.Version == 1 {
return readGGUFV1Array(llm, r)
}
type array[T any] struct {
// size is the actual size of the array
size int
// values is the array of values. this is nil if the array is larger than configured maxSize
values []T
}
func (a *array[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func newArray[T any](size, maxSize int) *array[T] {
a := array[T]{size: size}
if maxSize < 0 || size <= maxSize {
a.values = make([]T, size)
}
return &a
}
func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
@@ -438,45 +413,55 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
return nil, err
}
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, int(n))
}
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
if a.values != nil {
e, err = readGGUFString(llm, r)
} else {
err = discardGGUFString(llm, r)
}
default:
return nil, fmt.Errorf("invalid array type: %d", t)
switch t {
case ggufTypeUint8:
a := newArray[uint8](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt8:
a := newArray[int8](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint16:
a := newArray[uint16](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt16:
a := newArray[int16](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint32:
a := newArray[uint32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt32:
a := newArray[int32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint64:
a := newArray[uint64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt64:
a := newArray[int64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeFloat32:
a := newArray[float32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeFloat64:
a := newArray[float64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeBool:
a := newArray[bool](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeString:
a := newArray[string](int(n), llm.maxArraySize)
if llm.Version == 1 {
return readGGUFV1StringsData(llm, r, a)
}
return readGGUFStringsData(llm, r, a)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
}
func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) {
for i := range a.size {
e, err := readGGUF[T](llm, r)
if err != nil {
return nil, err
}
@@ -506,22 +491,22 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(f *os.File, kv KV, ts []Tensor) error {
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
alignment := kv.Uint("general.alignment", 32)
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(ts))); err != nil {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
return err
}
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
return err
}
@@ -529,7 +514,7 @@ func WriteGGUF(f *os.File, kv KV, ts []Tensor) error {
slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(f, key, kv[key]); err != nil {
if err := ggufWriteKV(ws, key, kv[key]); err != nil {
return err
}
}
@@ -545,34 +530,21 @@ func WriteGGUF(f *os.File, kv KV, ts []Tensor) error {
})
var s uint64
for i := range ts {
ts[i].Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
if err := ggufWriteTensorInfo(f, ts[i]); err != nil {
for _, t := range ts {
t.Offset = s + uint64(ggufPadding(int64(s), int64(alignment)))
if err := ggufWriteTensorInfo(ws, t); err != nil {
return err
}
s += ts[i].Size()
s += t.Size()
}
offset, err := f.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
offset += ggufPadding(offset, int64(alignment))
slog.Debug("gguf", "offset", offset, "size", s, "alignment", alignment)
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range ts {
t := t
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
g.Go(func() error {
_, err := t.WriteTo(w)
if err := ggufWriteTensor(ws, t, int64(alignment)); err != nil {
return err
})
}
}
return g.Wait()
return nil
}
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
@@ -644,8 +616,8 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
return err
}
for i := range len(t.Shape) {
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
for _, n := range t.Shape {
if err := binary.Write(ws, binary.LittleEndian, n); err != nil {
return err
}
}
@@ -657,6 +629,20 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
return binary.Write(ws, binary.LittleEndian, t.Offset)
}
func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
offset, err := ws.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
return err
}
_, err = t.WriteTo(ws)
return err
}
func ggufPadding(offset, align int64) int64 {
return (align - offset%align) % align
}

12
go.mod
View File

@@ -11,7 +11,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.11.0
golang.org/x/sync v0.12.0
)
require (
@@ -70,12 +70,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.33.0
golang.org/x/crypto v0.36.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.30.0
golang.org/x/term v0.29.0
golang.org/x/text v0.22.0
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0
golang.org/x/term v0.30.0
golang.org/x/text v0.23.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

24
go.sum
View File

@@ -214,8 +214,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
}
res, err := embeddingTestHelper(ctx, t, req)
res, err := embeddingTestHelper(ctx, client, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, t, req)
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) {
func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"},
}
res, err := embedTestHelper(ctx, t, req)
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
func TestAllMiniLMEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
truncTrue, truncFalse := true, false
@@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := embedTestHelper(ctx, t, req.Request)
response, err := embedTestHelper(ctx, client, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
@@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
@@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
}
}
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
@@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq
return response, nil
}
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}

View File

@@ -21,6 +21,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type Causal struct {
DType ml.DType
windowSize int32
chunkSize int32
opts CausalOptions
@@ -97,6 +98,17 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}
}
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
return &Causal{
windowSize: math.MaxInt32,
chunkSize: chunkSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil {
var config ml.CacheConfig
@@ -300,6 +312,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}

View File

@@ -86,6 +86,64 @@ func TestSWA(t *testing.T) {
testCache(t, backend, cache, tests)
}
func TestChunkedAttention(t *testing.T) {
cache := NewChunkedAttentionCache(2, nil)
defer cache.Close()
var b testBackend
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
x := float32(math.Inf(-1))
testCache(
t, &b, cache,
[]testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{
0, x, x, x,
0, 0, x, x,
x, x, 0, x,
x, x, 0, 0,
},
},
{
name: "SecondBatch",
in: []float32{5, 6, 7},
inShape: []int{1, 1, 3},
seqs: []int{0, 0, 0},
pos: []int32{4, 5, 6},
expected: []float32{1, 2, 3, 4, 5, 6, 7},
expectedShape: []int{1, 1, 7},
expectedMask: []float32{
x, x, x, x, 0, x, x,
x, x, x, x, 0, 0, x,
x, x, x, x, x, x, 0,
},
},
{
name: "ThirdBatch",
in: []float32{8, 9},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{7, 8},
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
expectedShape: []int{1, 1, 9},
expectedMask: []float32{
x, x, x, x, x, x, 0, 0, x,
x, x, x, x, x, x, x, x, 0,
},
},
},
)
}
func TestSequences(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(nil)
@@ -293,8 +351,16 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
if !slices.Equal(out.Floats(), test.expected) {
t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
}
if !slices.Equal(out.Shape(), test.expectedShape) {
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
}
if !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
}
})
}

View File

@@ -414,7 +414,7 @@ func projectorMemoryRequirements(filename string) (weights, graphSize uint64) {
}
defer file.Close()
ggml, _, err := ggml.Decode(file, 0)
ggml, _, err := ggml.Decode(file, 1024)
if err != nil {
return 0, 0
}

View File

@@ -329,11 +329,13 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
}
ggmlPaths := []string{discover.LibOllamaPath}
if len(compatible) > 0 {
c := compatible[0]
if libpath, ok := libs[c]; ok {
slog.Debug("adding gpu library", "path", libpath)
libraryPaths = append(libraryPaths, libpath)
ggmlPaths = append(ggmlPaths, libpath)
}
}
@@ -369,6 +371,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
s.cmd.Stderr = s.status
s.cmd.SysProcAttr = LlamaServerSysProcAttr
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
envWorkarounds := [][2]string{}
for _, gpu := range gpus {
envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
@@ -406,7 +410,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
if envconfig.Debug() {
filteredEnv := []string{}
for _, ev := range s.cmd.Env {
if strings.HasPrefix(ev, "CUDA_") ||
if strings.HasPrefix(ev, "OLLAMA_") ||
strings.HasPrefix(ev, "CUDA_") ||
strings.HasPrefix(ev, "ROCR_") ||
strings.HasPrefix(ev, "ROCM_") ||
strings.HasPrefix(ev, "HIP_") ||

View File

@@ -133,6 +133,7 @@ type Tensor interface {
Mul(ctx Context, t2 Tensor) Tensor
Mulmat(ctx Context, t2 Tensor) Tensor
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
MulmatID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
@@ -150,6 +151,7 @@ type Tensor interface {
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
SILU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor
Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor
@@ -168,6 +170,8 @@ type Tensor interface {
Rows(ctx Context, t2 Tensor) Tensor
Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor
TopK(ctx Context, k int) Tensor
}
// ScaledDotProductAttention implements a fused attention

View File

@@ -884,17 +884,32 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
if b != nil {
tt = tt.Add(ctx, b)
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil {
tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
if b != nil {
tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t)
}
}
return tt
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil {
tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -995,6 +1010,13 @@ func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
if len(shape) != 4 {
panic("expected 4 dimensions")
@@ -1158,3 +1180,10 @@ func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
t: C.ggml_dup(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
}
}

View File

@@ -57,26 +57,20 @@ var OnceLoad = sync.OnceFunc(func() {
exe = "."
}
// PATH, LD_LIBRARY_PATH, and DYLD_LIBRARY_PATH are often
// set by the parent process, however, use a default value
// if the environment variable is not set.
var name, value string
var value string
switch runtime.GOOS {
case "darwin":
// On macOS, DYLD_LIBRARY_PATH is often not set, so
// we use the directory of the executable as the default.
name = "DYLD_LIBRARY_PATH"
value = filepath.Dir(exe)
case "windows":
name = "PATH"
value = filepath.Join(filepath.Dir(exe), "lib", "ollama")
default:
name = "LD_LIBRARY_PATH"
value = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
}
paths, ok := os.LookupEnv(name)
// Avoid potentially loading incompatible GGML libraries
paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
if !ok {
slog.Debug("OLLAMA_LIBRARY_PATH not set, falling back to default", "search", value)
paths = value
}

View File

@@ -42,7 +42,7 @@ func New(c fs.Config) (model.Model, error) {
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
Types: c.Ints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},

View File

@@ -59,7 +59,7 @@ func New(c fs.Config) (model.Model, error) {
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
Types: c.Ints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(1),

View File

@@ -49,7 +49,7 @@ func newTextModel(c fs.Config) *TextModel {
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
Types: c.Ints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},

View File

@@ -41,7 +41,7 @@ func New(c fs.Config) (model.Model, error) {
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),

View File

@@ -0,0 +1,189 @@
package llama4
import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.BytePairEncoding
ImageProcessor
*VisionModel `gguf:"v,vision"`
*Projector `gguf:"mm"`
*TextModel
}
type Projector struct {
Linear1 *nn.Linear `gguf:"linear_1"`
}
func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
return p.Linear1.Forward(ctx, visionOutputs)
}
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewChunkedAttentionCache(int32(c.Uint("attention.chunk_size", 8192)), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) < 1 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
pixelsLocal, pixelsGlobal, size, err := m.ProcessImage(img)
if err != nil {
return nil, err
}
tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
if err != nil {
return nil, err
}
ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, ratioW, size.Y, m.numChannels).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW*size.Y/ratioH, ratioH, ratioW, m.numChannels).Permute(ctx, 0, 3, 2, 1).Contiguous(ctx)
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, size.Y/ratioH, m.numChannels, ratioH*ratioW)
pixelValues := tilesLocal
if len(pixelsGlobal) > 0 {
tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
if err != nil {
return nil, err
}
pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3)
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
}
type chunks struct {
*Model
ml.Tensor
aspectRatio image.Point
dataOnce sync.Once
data []float32
}
type chunk struct {
*chunks
s, n int
}
func (r *chunk) floats() []float32 {
r.dataOnce.Do(func() {
temp := r.Backend().NewContext()
defer temp.Close()
temp.Forward(r.Tensor).Compute(r.Tensor)
r.data = r.Floats()
})
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
continue
}
t := inp.Multimodal.(*chunks)
var imageInputs []input.Input
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
var offset int
patchesPerChunk := t.Dim(1)
if t.aspectRatio.Y*t.aspectRatio.X > 1 {
patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
for range t.aspectRatio.Y {
for x := range t.aspectRatio.X {
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if x < t.aspectRatio.X-1 {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
}
offset += patchesPerChunk
}
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
}
}
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
result = append(result, imageInputs...)
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("llama4", New)
}

View File

@@ -0,0 +1,259 @@
package llama4
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
)
type TextAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_factors"`
}
func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope {
query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
}
if opts.useQKNorm {
query = query.RMSNorm(ctx, nil, opts.eps)
key = key.RMSNorm(ctx, nil, opts.eps)
}
if attentionScales != nil && !useRope {
query = query.Mul(ctx, attentionScales)
}
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextExperts struct {
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
}
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
experts := routerLogits.TopK(ctx, opts.numExpertsUsed)
scores := routerLogits.Sigmoid(ctx).Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, experts)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores)
upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
}
return nextStates
}
// TextSharedExpert is TextMLP with different tensor names
type TextSharedExpert struct {
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
Up *nn.Linear `gguf:"ffn_up_shexp"`
Down *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextMOE struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Experts *TextExperts
SharedExpert *TextSharedExpert
}
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
routerLogits := moe.Router.Forward(ctx, hiddenStates)
sharedStates := moe.SharedExpert.Forward(ctx, hiddenStates, opts)
routedStates := moe.Experts.Forward(ctx, hiddenStates, routerLogits, opts)
return sharedStates.Add(ctx, routedStates)
}
type TextFeedForward interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor
}
type TextLayer struct {
AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
Attention *TextAttention
FFNNorm *nn.LayerNorm `gguf:"ffn_norm"`
FeedForward TextFeedForward
}
func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
residual := hiddenStates
// self attention
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, useRope, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = d.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.FeedForward.Forward(ctx, hiddenStates, opts)
return residual.Add(ctx, hiddenStates)
}
type TextOptions struct {
hiddenSize int
numHeads, numKVHeads, headDim int
numExperts, numExpertsUsed int
ropeDim int
ropeBase, ropeScale float32
eps float32
interleaveLayerStep int
noRopeInterval int
useQKNorm bool
attentionTemperatureTuning bool
attentionScale float64
attentionFloorScale float64
}
type TextModel struct {
Layers []TextLayer `gguf:"blk"`
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.LayerNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
func newTextModel(c fs.Config) *TextModel {
layers := make([]TextLayer, c.Uint("block_count"))
interleaveLayerStep := c.Uint("interleave_moe_layer_step", 1)
for i := range layers {
if (i+1)%int(interleaveLayerStep) == 0 {
layers[i] = TextLayer{FeedForward: &TextMOE{}}
} else {
layers[i] = TextLayer{FeedForward: &TextMLP{}}
}
}
return &TextModel{
Layers: layers,
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.head_dim", 128)),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"),
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
noRopeInterval: int(c.Uint("no_rope_interval", 4)),
useQKNorm: c.Bool("use_qk_norm", true),
attentionTemperatureTuning: c.Bool("attention.temperature_tuning", true),
attentionScale: float64(c.Float("attention.scale", 0.1)),
attentionFloorScale: float64(c.Float("attention.floor_scale", 8192)),
},
}
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, mi := range batch.Multimodal {
f32s := mi.Multimodal.(*chunk).floats()
img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
if err != nil {
panic(err)
}
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
}
var attentionScales ml.Tensor
if m.attentionTemperatureTuning {
scales := make([]float32, len(batch.Positions))
for i, p := range batch.Positions {
scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0)
}
var err error
attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
if err != nil {
panic(err)
}
}
for i, layer := range m.Layers {
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(1)
useChunkedAttention := (i+1)%m.noRopeInterval != 0
if useChunkedAttention {
wc.SetLayerType(0)
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
}

View File

@@ -0,0 +1,256 @@
package llama4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type VisionAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor.
// This is equivalent to the Pytorch implmentation using half rotations:
//
// cos, sin = torch.cos(freqs), torch.sin(freqs)
// cos = cos.unsqueeze(-1)
// sin = sin.unsqueeze(-1)
// t = t.reshape(*t.shape[:-1], -1, 2)
// t_out = (t * cos) + (_rotate_half(t) * sin)
// t_out = t_out.flatten(3)
//
// Which is equivalent to the Pytorch implementation using complex numbers:
//
// t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
// freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_) # freqs_ci[:,:,None,:]
// freqs_ci = freqs_ci.to(t_.device)
// t_out = torch.view_as_real(t_ * freqs_ci).flatten(3)
//
// Due to the 1) the dimensional and 2) the datatype limitations of current backends,
// we need to use a different approach to achieve the same result.
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
// t1 = t[..., 0::2]
t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
// t2 = t[..., 1::2]
t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
return cosOut.Add(ctx, sinOut)
}
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2))
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2))
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2))
query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
key = applyVisionRotaryEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx)
hiddenStates = mlp.FC2.Forward(ctx, hiddenStates)
return hiddenStates
}
type VisionLayer struct {
InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"`
*VisionAttention
PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"`
*VisionMLP `gguf:"mlp"`
}
func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
residual := hiddenStates
// self attention
hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
// MLP
residual = hiddenStates
hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type VisionAdapter struct {
FC1 *nn.Linear `gguf:"mlp.fc1"`
FC2 *nn.Linear `gguf:"mlp.fc2"`
}
func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
patches := hiddenStates.Dim(1)
patchSize := int(math.Sqrt(float64(patches)))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2))
channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio)
hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio)
hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles)
hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx)
hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx)
return hiddenStates
}
type VisionOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
ropeTheta float32
eps float32
pixelShuffleRatio float32
}
type PatchEmbedding struct {
*nn.Linear
}
func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2))
hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3))
return p.Linear.Forward(ctx, hiddenStates)
}
type VisionModel struct {
Layers []VisionLayer `gguf:"blk"`
*PatchEmbedding `gguf:"patch_embedding"`
ClassEmbedding ml.Tensor `gguf:"class_embedding"`
PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"`
LayerNormPre *nn.LayerNorm `gguf:"layernorm_pre"`
LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"`
*VisionAdapter `gguf:"vision_adapter"`
*VisionOptions
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionLayer, c.Uint("vision.block_count")),
VisionOptions: &VisionOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
ropeTheta: float32(c.Float("vision.rope.freq_base")),
eps: c.Float("vision.layer_norm_epsilon"),
pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")),
},
}
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions)
hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1)
hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding)
hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps)
cos, sin := m.rotaryEmbedding(ctx)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
}
hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps)
hiddenStates = hiddenStates.Unpad(ctx, 0, 1, 0, 0)
hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions)
return hiddenStates
}
// floorDiv is a helper function to perform floor division. This mimics PyTorch's div(round_mode='floor') function
// which in turn mimics Python's // operator.
func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T {
if b == 0 {
panic("division by zero")
}
if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 {
return a / b
}
return a/b - 1
}
func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
patchesPerSide := m.imageSize / m.patchSize
numPatches := patchesPerSide*patchesPerSide + 1
headDim := m.hiddenSize / m.numHeads
freqDim := headDim / 2
freqs := make([]float32, numPatches*freqDim)
for i := range numPatches - 1 {
for j := 0; j < freqDim; j += 2 {
positionX := i*freqDim/2 + j/2
positionY := (i+numPatches)*freqDim/2 + j/2
ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim))
freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq)
freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq)
}
}
ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
if err != nil {
panic(err)
}
ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx)
}

View File

@@ -0,0 +1,167 @@
package llama4
import (
"cmp"
"image"
"math"
"slices"
"sort"
"golang.org/x/image/draw"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels, maxUpscalingSize int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels", 3)),
maxUpscalingSize: int(c.Uint("vision.max_upscaling_size", 448)),
}
}
func factors(n int) []int {
var result []int
seen := make(map[int]bool)
for i := 1; i <= n/2; i++ {
if n%i == 0 && !seen[i] {
result = append(result, i)
seen[i] = true
}
}
result = append(result, n)
sort.Ints(result)
return result
}
func (p ImageProcessor) supportedResolutions() []image.Point {
var resolutions []image.Point
aspectMap := make(map[float64][]image.Point)
for i := p.patchSize; i >= 1; i-- {
for _, f := range factors(i) {
x := f
y := i / f
k := float64(y) / float64(x)
aspectMap[k] = append(aspectMap[k], image.Point{x, y})
}
}
for _, v := range aspectMap {
for _, i := range v {
resolutions = append(resolutions, image.Point{i.X * p.imageSize, i.Y * p.imageSize})
}
}
return resolutions
}
func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []image.Point, resizeToMaxCanvas bool) image.Point {
w, h := img.X, img.Y
scales := make([]float64, len(possibleResolutions))
for i, res := range possibleResolutions {
scaleW := float64(res.X) / float64(w)
scaleH := float64(res.Y) / float64(h)
scale := math.Min(scaleW, scaleH)
scales[i] = scale
}
minAboveOne := func(scales []float64) (float64, bool) {
min := math.MaxFloat64
found := false
for _, s := range scales {
if s >= 1.0 && s < min {
min = s
found = true
}
}
return min, found
}
bestScale, ok := minAboveOne(scales)
if resizeToMaxCanvas || !ok {
bestScale = slices.Max(scales)
}
var bestOptions []image.Point
for i, scale := range scales {
if math.Abs(scale-bestScale) < 1e-6 {
bestOptions = append(bestOptions, possibleResolutions[i])
}
}
var chosenResolution image.Point
if len(bestOptions) > 1 {
chosenResolution = slices.MinFunc(bestOptions, func(a, b image.Point) int {
return cmp.Compare(a.X*a.Y, b.X*b.Y)
})
} else {
chosenResolution = bestOptions[0]
}
return chosenResolution
}
func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Point {
scaleW := float64(targetRes.X) / float64(imageRes.X)
scaleH := float64(targetRes.Y) / float64(imageRes.Y)
var newRes image.Point
if scaleW < scaleH {
newRes = image.Point{
targetRes.X,
int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
}
} else {
newRes = image.Point{
int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
targetRes.Y,
}
}
return newRes
}
func (p ImageProcessor) pad(src image.Image, outputSize image.Point) image.Image {
dst := image.NewRGBA(image.Rect(0, 0, outputSize.X, outputSize.Y))
draw.Draw(dst, src.Bounds(), src, image.Point{}, draw.Over)
return dst
}
func (p ImageProcessor) ProcessImage(img image.Image) (pixelsLocal, pixelsGlobal []float32, targetSize image.Point, _ error) {
img = imageproc.Composite(img)
targetSize = p.bestResolution(img.Bounds().Max, p.supportedResolutions(), false)
targetSizeWithoutDistortion := targetSize
if p.maxUpscalingSize > 0 {
targetSizeWithoutDistortion = p.maxResolution(img.Bounds().Max, targetSize)
targetSizeWithoutDistortion.X = min(max(img.Bounds().Max.X, p.maxUpscalingSize), targetSize.X)
targetSizeWithoutDistortion.Y = min(max(img.Bounds().Max.Y, p.maxUpscalingSize), targetSize.Y)
}
newSizeWithoutDistortion := p.maxResolution(img.Bounds().Max, targetSizeWithoutDistortion)
padded := p.pad(imageproc.Resize(img, newSizeWithoutDistortion, imageproc.ResizeBilinear), targetSize)
pixelsLocal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true)
if targetSize.X/p.imageSize*targetSize.Y/p.imageSize > 1 {
padded := imageproc.Resize(img, image.Point{p.imageSize, p.imageSize}, imageproc.ResizeBilinear)
pixelsGlobal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true)
}
return pixelsLocal, pixelsGlobal, targetSize, nil
}

View File

@@ -0,0 +1,300 @@
package llama4
import (
"cmp"
"image"
"image/color"
"reflect"
"slices"
"testing"
gocmp "github.com/google/go-cmp/cmp"
)
func TestFactors(t *testing.T) {
tests := []struct {
name string
input int
expected []int
}{
{
name: "factors of 1",
input: 1,
expected: []int{1},
},
{
name: "factors of 2",
input: 2,
expected: []int{1, 2},
},
{
name: "factors of 6",
input: 6,
expected: []int{1, 2, 3, 6},
},
{
name: "factors of 28",
input: 28,
expected: []int{1, 2, 4, 7, 14, 28},
},
{
name: "factors of 49",
input: 49,
expected: []int{1, 7, 49},
},
{
name: "factors of 97 (prime)",
input: 97,
expected: []int{1, 97},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := factors(tt.input)
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("factors(%d) = %v; want %v", tt.input, actual, tt.expected)
}
})
}
}
func TestSupportedResolutions(t *testing.T) {
expectedResolutions := []image.Point{
{X: 3360, Y: 336},
{X: 672, Y: 2688},
{X: 336, Y: 1344},
{X: 336, Y: 4032},
{X: 1008, Y: 1344},
{X: 1344, Y: 1008},
{X: 336, Y: 1680},
{X: 1680, Y: 336},
{X: 336, Y: 5040},
{X: 4032, Y: 336},
{X: 2352, Y: 336},
{X: 2688, Y: 672},
{X: 1344, Y: 336},
{X: 5376, Y: 336},
{X: 2352, Y: 672},
{X: 672, Y: 1008},
{X: 1008, Y: 672},
{X: 336, Y: 5376},
{X: 1680, Y: 1008},
{X: 5040, Y: 336},
{X: 336, Y: 3024},
{X: 3024, Y: 336},
{X: 336, Y: 2688},
{X: 672, Y: 1344},
{X: 336, Y: 672},
{X: 336, Y: 2352},
{X: 2016, Y: 672},
{X: 1008, Y: 336},
{X: 336, Y: 3360},
{X: 336, Y: 4368},
{X: 1008, Y: 1680},
{X: 336, Y: 4704},
{X: 4704, Y: 336},
{X: 1344, Y: 672},
{X: 672, Y: 336},
{X: 2688, Y: 336},
{X: 3696, Y: 336},
{X: 2016, Y: 336},
{X: 1344, Y: 1344},
{X: 1008, Y: 1008},
{X: 672, Y: 672},
{X: 336, Y: 336},
{X: 4368, Y: 336},
{X: 672, Y: 2016},
{X: 336, Y: 1008},
{X: 336, Y: 3696},
{X: 672, Y: 1680},
{X: 1680, Y: 672},
{X: 336, Y: 2016},
{X: 672, Y: 2352},
}
sortResolutionFunc := func(a, b image.Point) int {
return cmp.Or(cmp.Compare(a.X, b.X), cmp.Compare(a.Y, b.Y))
}
slices.SortStableFunc(expectedResolutions, sortResolutionFunc)
imgProc := ImageProcessor{
imageSize: 336,
patchSize: 16,
numChannels: 3,
maxUpscalingSize: 448,
}
actualResolutions := imgProc.supportedResolutions()
slices.SortStableFunc(actualResolutions, sortResolutionFunc)
if diff := gocmp.Diff(expectedResolutions, actualResolutions); diff != "" {
t.Errorf("supportedResolutions() mismatch (-want +got):\n%s", diff)
}
}
func TestBestResolution(t *testing.T) {
tests := []struct {
name string
size image.Point
resolutions []image.Point
max bool
expected image.Point
}{
{
"normal",
image.Point{800, 600},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{800, 600},
},
{
"max",
image.Point{800, 600},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
true,
image.Point{1600, 1200},
},
{
"mid",
image.Point{1000, 700},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{1024, 768},
},
{
"smol",
image.Point{100, 100},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{300, 200},
},
{
"huge",
image.Point{10000, 10000},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{1600, 1200},
},
}
p := ImageProcessor{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := p.bestResolution(tt.size, tt.resolutions, tt.max)
if diff := gocmp.Diff(tt.expected, actual); diff != "" {
t.Errorf("best resolution mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestMaxResolution(t *testing.T) {
tests := []struct {
name string
origRes image.Point
targetRes image.Point
expected image.Point
}{
{
"normal",
image.Point{800, 600},
image.Point{800, 600},
image.Point{800, 600},
},
{
"skew",
image.Point{800, 600},
image.Point{1100, 700},
image.Point{933, 700},
},
}
p := ImageProcessor{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := p.maxResolution(tt.origRes, tt.targetRes)
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("max resolution; got %v want %v", actual, tt.expected)
}
})
}
}
func TestProcessImage(t *testing.T) {
imgProc := ImageProcessor{
imageSize: 336,
patchSize: 16,
numChannels: 3,
maxUpscalingSize: 448,
}
generateImage := func(seed int) image.Image {
width, height := 20, 10
img := image.NewRGBA(image.Rect(0, 0, width, height))
for x := range width {
// Use the seed to vary color generation
r := uint8((seed + x*11) % 256)
g := uint8((seed + x*17) % 256)
b := uint8((seed + x*23) % 256)
c := color.RGBA{R: r, G: g, B: b, A: 255}
for y := range height {
img.Set(x, y, c)
}
}
return img
}
pixelsLocal, pixelsGlobal, targetSize, err := imgProc.ProcessImage(generateImage(12))
if err != nil {
t.Error(err)
}
if n := len(pixelsLocal); n != 336*336*3 {
t.Errorf("unexpected size of f32s: %d", n)
}
if n := len(pixelsGlobal); n > 0 {
t.Errorf("unexpected size of f32s: %d", n)
}
if !targetSize.Eq(image.Point{336, 336}) {
t.Errorf("unexpected target size: %v", targetSize)
}
}

View File

@@ -152,7 +152,7 @@ func NewTextModel(c fs.Config) (*TextModel, error) {
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),

View File

@@ -43,7 +43,7 @@ func New(c fs.Config) (model.Model, error) {
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),

View File

@@ -177,7 +177,7 @@ type TextDecoder struct {
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers {
layerType := selfAttentionLayer
if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
if slices.Contains(opts.crossAttentionLayers, int32(i)) {
layerType = crossAttentionLayer
}
@@ -202,7 +202,7 @@ type TextModelOptions struct {
eps, ropeBase, ropeScale float32
ropeDim uint32
crossAttentionLayers []uint32
crossAttentionLayers []int32
}
type TextModel struct {
@@ -225,7 +225,7 @@ func newTextModel(c fs.Config) *TextModel {
var decoderLayers []TextDecoderLayer
for i := range c.Uint("block_count") {
var textDecoderLayer TextDecoderLayer
if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) {
textDecoderLayer = &TextCrossAttentionDecoderLayer{}
} else {
textDecoderLayer = &TextSelfAttentionDecoderLayer{}
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
},
}
}

View File

@@ -96,10 +96,10 @@ type VisionEncoder struct {
Layers []VisionEncoderLayer
}
func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []int32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
var intermediateHiddenStates []ml.Tensor
for i, layer := range e.Layers {
if slices.Contains(intermediateLayersIndices, uint32(i)) {
if slices.Contains(intermediateLayersIndices, int32(i)) {
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
}
@@ -154,7 +154,7 @@ type VisionModelOptions struct {
imageSize, patchSize int
eps float32
intermediateLayersIndices []uint32
intermediateLayersIndices []int32
}
type VisionModel struct {
@@ -229,7 +229,7 @@ func newVisionModel(c fs.Config) *VisionModel {
eps: c.Float("vision.attention.layer_norm_epsilon"),
intermediateLayersIndices: c.Uints("vision.intermediate_layers_indices"),
intermediateLayersIndices: c.Ints("vision.intermediate_layers_indices"),
},
}
}

View File

@@ -4,6 +4,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@@ -37,7 +37,7 @@ type TextProcessor interface {
type Vocabulary struct {
Values []string
Types []uint32
Types []int32
Scores []float32
Merges []string

View File

@@ -35,9 +35,9 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, uint32(t))
v.Types = append(v.Types, int32(t))
default:
tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed
@@ -124,7 +124,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
"<0xC3>",
"<0xA3>",
},
Types: []uint32{
Types: []int32{
TOKEN_TYPE_NORMAL,
TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE,

View File

@@ -28,7 +28,7 @@ func llama(t testing.TB) BytePairEncoding {
t.Fatal(err)
}
types := make([]uint32, len(vocab))
types := make([]int32, len(vocab))
tokens := make([]string, len(vocab))
for token, id := range vocab {
tokens[id] = token

View File

@@ -64,7 +64,7 @@ func formatDuration(d time.Duration) string {
func (b *Bar) String() string {
termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
if err != nil {
termWidth = defaultTermWidth
termWidth = 80
}
var pre strings.Builder

View File

@@ -4,16 +4,8 @@ import (
"bufio"
"fmt"
"io"
"os"
"sync"
"time"
"golang.org/x/term"
)
const (
defaultTermWidth = 80
defaultTermHeight = 24
)
type State interface {
@@ -91,11 +83,6 @@ func (p *Progress) Add(key string, state State) {
}
func (p *Progress) render() {
_, termHeight, err := term.GetSize(int(os.Stderr.Fd()))
if err != nil {
termHeight = defaultTermHeight
}
p.mu.Lock()
defer p.mu.Unlock()
@@ -115,9 +102,8 @@ func (p *Progress) render() {
fmt.Fprint(p.w, "\033[1G")
// render progress lines
maxHeight := min(len(p.states), termHeight)
for i := len(p.states) - maxHeight; i < len(p.states); i++ {
fmt.Fprint(p.w, p.states[i].String(), "\033[K")
for i, state := range p.states {
fmt.Fprint(p.w, state.String(), "\033[K")
if i < len(p.states)-1 {
fmt.Fprint(p.w, "\n")
}

View File

@@ -723,7 +723,9 @@ func (m *multiLPath) String() string {
return strings.Join(*m, ", ")
}
func (s *Server) reserveWorstCaseGraph() error {
// TODO(jessegross): This is causing tensor allocation failures with large batches when not offloaded
// to the GPU
/*func (s *Server) reserveWorstCaseGraph() error {
ctx := s.model.Backend().NewContext()
defer ctx.Close()
@@ -766,7 +768,7 @@ func (s *Server) reserveWorstCaseGraph() error {
}
return nil
}
}*/
func (s *Server) loadModel(
ctx context.Context,
@@ -803,10 +805,10 @@ func (s *Server) loadModel(
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
err = s.reserveWorstCaseGraph()
/*err = s.reserveWorstCaseGraph()
if err != nil {
panic(err)
}
}*/
s.status = llm.ServerStatusReady
s.ready.Done()

View File

@@ -74,7 +74,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
t.Fatal(err)
}
types := make([]uint32, len(vocab))
tokens := make([]string, len(vocab))
for token, id := range vocab {
tokens[id] = token
@@ -86,7 +85,7 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
``,
&model.Vocabulary{
Values: tokens,
Types: types,
Types: make([]int32, len(vocab)),
Merges: merges,
},
)

View File

@@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
}
defer bin.Close()
f, _, err := ggml.Decode(bin, 0)
f, _, err := ggml.Decode(bin, 1024)
if err != nil {
return nil, err
}
@@ -457,7 +457,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err
}
f, _, err := ggml.Decode(temp, 0)
f, _, err := ggml.Decode(temp, 1024)
if err != nil {
slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
return nil, err
@@ -499,7 +499,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var offset int64
for offset < stat.Size() {
f, n, err := ggml.Decode(blob, 0)
f, n, err := ggml.Decode(blob, 1024)
if errors.Is(err, io.EOF) {
break
} else if err != nil {

View File

@@ -75,7 +75,7 @@ func (m *Model) Capabilities() []model.Capability {
if err == nil {
defer r.Close()
f, _, err := ggml.Decode(r, 0)
f, _, err := ggml.Decode(r, 1024)
if err == nil {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)

View File

@@ -73,8 +73,13 @@ type statusCodeRecorder struct {
func (r *statusCodeRecorder) WriteHeader(status int) {
if r._status == 0 {
r._status = status
r.ResponseWriter.WriteHeader(status)
}
r.ResponseWriter.WriteHeader(status)
}
func (r *statusCodeRecorder) Write(b []byte) (int, error) {
r._status = r.status()
return r.ResponseWriter.Write(b)
}
var (

View File

@@ -64,7 +64,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
defer blob.Close()
f, _, err := ggml.Decode(blob, 0)
f, _, err := ggml.Decode(blob, 1024)
if err != nil {
return nil, err
}

View File

@@ -18,6 +18,7 @@ import (
"os"
"os/signal"
"path/filepath"
"regexp"
"slices"
"strings"
"syscall"
@@ -1512,6 +1513,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Messages[0].Role != "system" && m.System != "" {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
msgs = filterThinkTags(msgs, m)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil {
@@ -1640,3 +1642,23 @@ func handleScheduleError(c *gin.Context, name string, err error) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
finalUserIndex := -1
for i, msg := range msgs {
if msg.Role == "user" {
finalUserIndex = i
}
}
for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex {
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
}
}
}
return msgs
}

View File

@@ -299,9 +299,6 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
Options: map[string]any{
"num_ctx": 1024,
},
})
if w.Code != http.StatusOK {
@@ -324,9 +321,6 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
Options: map[string]any{
"num_ctx": 1024,
},
})
if w.Code != http.StatusOK {
@@ -350,9 +344,6 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "Help me write tests."},
},
Stream: &stream,
Options: map[string]any{
"num_ctx": 1024,
},
})
if w.Code != http.StatusOK {

View File

@@ -15,6 +15,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"sort"
"strings"
"testing"
@@ -746,3 +747,128 @@ func TestNormalize(t *testing.T) {
})
}
}
func TestFilterThinkTags(t *testing.T) {
type testCase struct {
msgs []api.Message
want []api.Message
model *Model
}
testCases := []testCase{
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "qwen3",
},
},
},
// with newlines inside the think tag aned newlines after
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... \n\nabout \nthe answer</think>\n\nabc\ndef"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc\ndef"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "qwen3",
},
},
},
// should leave thinking tags if it's after the last user message
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking...</think>after"},
{Role: "user", Content: "What is the answer?"},
{Role: "assistant", Content: "<think>thinking again</think>hjk"},
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "after"},
{Role: "user", Content: "What is the answer?"},
{Role: "assistant", Content: "<think>thinking again</think>hjk"},
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "qwen3",
},
},
},
{
// shouldn't strip anything because the model family isn't one of the hardcoded ones
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: ConfigV2{
ModelFamily: "llama3",
},
},
},
{
// deepseek-r1:-prefixed model
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Name: "registry.ollama.ai/library/deepseek-r1:latest",
ShortName: "deepseek-r1:7b",
Config: ConfigV2{},
},
},
}
for i, tc := range testCases {
filtered := filterThinkTags(tc.msgs, tc.model)
if !reflect.DeepEqual(filtered, tc.want) {
t.Errorf("messages differ for case %d:", i)
for i := range tc.want {
if i >= len(filtered) {
t.Errorf(" missing message %d: %+v", i, tc.want[i])
continue
}
if !reflect.DeepEqual(filtered[i], tc.want[i]) {
t.Errorf(" message %d:\n want: %+v\n got: %+v", i, tc.want[i], filtered[i])
}
}
if len(filtered) > len(tc.want) {
for i := len(tc.want); i < len(filtered); i++ {
t.Errorf(" extra message %d: %+v", i, filtered[i])
}
}
}
}
}

View File

@@ -81,6 +81,10 @@ func InitScheduler(ctx context.Context) *Scheduler {
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
req := &LlmRequest{
ctx: c,
model: model,
@@ -110,11 +114,6 @@ func (s *Scheduler) Run(ctx context.Context) {
}()
}
const (
defaultContextLength = 4096
smallGpuContextLength = 2048
)
func (s *Scheduler) processPending(ctx context.Context) {
for {
select {
@@ -167,17 +166,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
gpus = s.getGpuFn()
}
if pending.origNumCtx == -1 {
if len(gpus) == 1 && gpus[0].Library != "cpu" && gpus[0].TotalMemory <= 4096*1024*1024 {
slog.Info("GPU is small, limiting default context window", "num_ctx", smallGpuContextLength)
pending.opts.NumCtx = smallGpuContextLength
pending.origNumCtx = smallGpuContextLength
} else {
pending.opts.NumCtx = defaultContextLength
pending.origNumCtx = defaultContextLength
}
}
if envconfig.MaxRunners() <= 0 {
// No user specified MaxRunners, so figure out what automatic setting to use
// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
@@ -453,10 +441,9 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
estimatedVRAM: llama.EstimatedVRAM(),
estimatedTotal: llama.EstimatedTotal(),
loading: true,
refCount: 1,
}
runner.numParallel = numParallel
runner.refMu.Lock()
runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
@@ -467,13 +454,13 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
defer runner.refMu.Unlock()
if err = llama.WaitUntilRunning(req.ctx); err != nil {
slog.Error("error loading llama server", "error", err)
runner.refCount--
req.errCh <- err
slog.Debug("triggering expiration for failed load", "model", runner.modelPath)
s.expiredCh <- runner
return
}
slog.Debug("finished setting up runner", "model", req.model.ModelPath)
runner.refCount++
runner.loading = false
go func() {
<-req.ctx.Done()
@@ -491,7 +478,12 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
}
predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
runners := make([]*runnerRef, 0, len(s.loaded))
for _, r := range s.loaded {
runners = append(runners, r)
}
s.loadedMu.Unlock()
for _, r := range runners {
r.refMu.Lock()
if r.llama != nil {
for _, gpu := range allGpus {
@@ -502,7 +494,6 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
}
r.refMu.Unlock()
}
s.loadedMu.Unlock()
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
for i := range allGpus {
@@ -549,10 +540,8 @@ func (s *Scheduler) filterGPUsWithoutLoadingModels(allGpus discover.GpuInfoList)
// TODO consolidate sched_types.go
type runnerRef struct {
refMu sync.Mutex
// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
refMu sync.Mutex
refCount uint // prevent unloading if > 0
// unloading bool // set to true when we are trying to unload the runner
llama llm.LlamaServer
loading bool // True only during initial load, then false forever
@@ -823,8 +812,8 @@ func (s *Scheduler) unloadAllRunners() {
func (s *Scheduler) expireRunner(model *Model) {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
runner, ok := s.loaded[model.ModelPath]
s.loadedMu.Unlock()
if ok {
runner.refMu.Lock()
runner.expiresAt = time.Now()

View File

@@ -148,7 +148,6 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
b.req.opts.NumCtx = 4096
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
return b
}