Merge pull request #9661 from ollama/gemma

engine: add gemma support
This commit is contained in:
Michael Yang 2025-03-11 15:07:50 -07:00 committed by GitHub
commit aee28501b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 1735 additions and 55 deletions

View File

@ -13,8 +13,13 @@ import (
)
type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
TextModel TextParameters `json:"text_config"`
}
type TextParameters struct {
VocabSize uint32 `json:"vocab_size"`
}
type AdapterParameters struct {
@ -185,6 +190,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
conv = &gemmaModel{}
case "Gemma2ForCausalLM":
conv = &gemma2Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "Qwen2ForCausalLM":
@ -213,7 +220,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
}
vocabSize := int(p.VocabSize)
if vocabSize == 0 {
tVocabSize := int(p.TextModel.VocabSize)
vocabSize = tVocabSize
}
switch {
case vocabSize == 0:
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
case vocabSize > len(t.Vocabulary.Tokens):
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) {

View File

@ -45,7 +45,7 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if strings.HasSuffix(t.Name(), "_norm.weight") {
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne)
}

142
convert/convert_gemma3.go Normal file
View File

@ -0,0 +1,142 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
gemmaModel
Architecture string
TextModel struct {
HeadDim uint32 `json:"head_dim"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
SlidingWindow uint32 `json:"sliding_window"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"` // attention.head_count 16
LayerNormEpsilon float32 `json:"layer_norm_eps"` // attention.layer_norm_epsilon 1e-05
NumHiddenLayers uint32 `json:"num_hidden_layers"` // block_count 32
HiddenSize uint32 `json:"hidden_size"` // embedding_length 1280
IntermediateSize uint32 `json:"intermediate_size"` // feed_forward_length 5120
ImageSize uint32 `json:"image_size"` // image_size 560
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
}
const (
gemma4BLayerCount = 34
gemma12BLayerCount = 48
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"
numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
kv["gemma3.block_count"] = numBlocks
var (
numHeads uint32
numKVHeads uint32
)
switch numBlocks {
case gemma4BLayerCount:
numHeads = 8
numKVHeads = 4
case gemma12BLayerCount:
numHeads = 16
numKVHeads = 8
case gemma27BLayerCount:
numHeads = 32
numKVHeads = 16
default:
numHeads = p.NumAttentionHeads
numKVHeads = p.NumKeyValueHeads
}
kv["gemma3.attention.head_count"] = numHeads
kv["gemma3.attention.head_count_kv"] = numKVHeads
switch p.Architecture {
case "Gemma3ForCausalLM":
kv["gemma3.context_length"] = p.MaxPositionEmbeddings
kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 8192)
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
}
if p.MultiModalTokensPerImage > 0 {
kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
}
return kv
}
func (p *gemma3Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"vision_tower.vision_model.embeddings", "v",
"vision_tower.vision_model", "v",
"vision_model.vision_model.embeddings", "v",
"vision_model.vision_model", "v",
"language_model.", "",
"model.layers", "blk",
"encoder.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.out_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"input_projection_weight", "input_projection.weight",
"multi_modal_projector", "mm",
}
}

View File

@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"io/fs"
"log/slog"
"os"
"reflect"
"slices"
"google.golang.org/protobuf/proto"
@ -15,6 +17,8 @@ import (
)
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
slog.Debug("using spm vocabulary")
ast, err := parseAdditionalSpecialTokens(fsys)
if err != nil {
return nil, err
@ -43,10 +47,19 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
v.Types = append(v.Types, int32(t))
default:
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
if slices.Contains(ast, piece.GetPiece()) {
// temporary fix to handle gemma3 broken configs
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
}
for _, t := range ast {
if t.Content == piece.GetPiece() {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
break
}
}
v.Types = append(v.Types, tt)
}
}
@ -78,10 +91,16 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
return cmp.Compare(i.id, j.id)
})
n := len(v.Tokens)
for i, t := range ts {
if t.id != i+n {
return nil, fmt.Errorf("invalid token id: %d", t.id)
for _, t := range ts {
if t.id < len(v.Tokens) {
if v.Tokens[t.id] == t.content {
slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
continue
}
return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
}
if t.id != len(v.Tokens) {
return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
}
v.Tokens = append(v.Tokens, t.content)
@ -92,7 +111,15 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
return &v, nil
}
func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
type specialToken struct {
Content string `json:"content"`
Lstrip bool `json:"lstrip"`
Normalized bool `json:"normalized"`
Rstrip bool `json:"rstrip"`
SingleWord bool `json:"single_word"`
}
func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
f, err := fsys.Open("special_tokens_map.json")
if errors.Is(err, os.ErrNotExist) {
return nil, nil
@ -102,12 +129,43 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
defer f.Close()
var m struct {
AdditionalSpecialTokens []string `json:"additional_special_tokens"`
AdditionalSpecialTokens any `json:"additional_special_tokens"`
}
if err := json.NewDecoder(f).Decode(&m); err != nil {
return nil, err
}
return m.AdditionalSpecialTokens, nil
var ast []specialToken
switch st := m.AdditionalSpecialTokens.(type) {
case []string:
for _, s := range st {
ast = append(ast, specialToken{Content: s})
}
case []any:
for _, s := range st {
// marshal and unmarshal the object to get the special token
tMap := s.(map[string]any)
data, err := json.Marshal(tMap)
if err != nil {
return nil, err
}
var token specialToken
err = json.Unmarshal(data, &token)
if err != nil {
return nil, err
}
ast = append(ast, token)
}
default:
slog.Warn("special token", "unknown token", reflect.TypeOf(st))
}
slog.Debug("spm tokenizer", "additional tokens", ast)
return ast, nil
}

View File

@ -124,6 +124,19 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
return s
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
r := keyValue(kv, key, &array{})
s := make([]float32, r.size)
for i := range r.size {
s[i] = float32(r.values[i].(float32))
}
return s
}
func (kv KV) OllamaEngineRequired() bool {
return kv.Architecture() == "gemma3"
}
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
@ -476,7 +489,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
// vocab graph
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
case "gemma", "gemma2":
case "gemma", "gemma2", "gemma3":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),

View File

@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type Causal struct {
DType ml.DType
Capacity int32
causal bool
windowSize int32
opts CausalOptions
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
@ -79,7 +80,6 @@ type cellRange struct {
func NewCausalCache(shift shiftFn) *Causal {
return &Causal{
causal: true,
windowSize: math.MaxInt32,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal {
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
return &Causal{
causal: true,
windowSize: windowSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
@ -145,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
c.curBatchSize = len(opts.Positions)
c.curSequences = opts.Sequences
c.curPositions = opts.Positions
c.opts.Except = nil
var err error
c.curLoc, err = c.findStartLoc()
@ -235,9 +235,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, i)
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(c.causal && c.cells[j].pos > c.curPositions[i]) ||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
@ -404,15 +405,16 @@ func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
// This state carries over to future forward passes. The default value is true.
//
// ctx may be set to nil if this is called from outside of a forward pass, for
// example, when initializing the cache.
func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
if c.causal != causal {
c.causal = causal
type CausalOptions struct {
// Enabled controls whether the causal mask is generated for a particular index in a batch
Except []int
}
// SetCausal disables causal mask generation for a particular range of indicies in
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts
if ctx != nil {
var err error
c.curMask, err = c.buildMask(ctx)

View File

@ -441,11 +441,19 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}
@ -495,6 +503,10 @@ func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}

View File

@ -0,0 +1,33 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Sun, 9 Mar 2025 14:44:16 -0700
Subject: [PATCH] ollama debug tensor
---
ggml/src/ggml-cpu/ggml-cpu.c | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 2f606d82..ec60e8fc 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -11,6 +11,8 @@
#include "ggml-threading.h"
#include "ggml.h"
+#include "ollama-debug.h"
+
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
ggml_compute_forward(&params, node);
+#ifdef OLLAMA_DEBUG
+ ollama_debug(node, true);
+#endif
+
if (state->ith == 0 && cplan->abort_callback &&
cplan->abort_callback(cplan->abort_callback_data)) {
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);

View File

@ -271,7 +271,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
var llamaModel *llama.Model
var textProcessor model.TextProcessor
if envconfig.NewEngine() {
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
textProcessor, err = model.NewTextProcessor(modelPath)
if err != nil {
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner

View File

@ -19,6 +19,7 @@ type Config interface {
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}
type Backend interface {
@ -134,8 +135,10 @@ type Tensor interface {
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
@ -145,6 +148,7 @@ type Tensor interface {
View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
Pad(ctx Context, shape ...int) Tensor
Unpad(ctx Context, shape ...int) Tensor

View File

@ -240,11 +240,22 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
switch {
case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
createTensor(tensor{source: t}, input.bts)
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
createTensor(tensor{source: t, target: "output.weight"}, output.bts)
}
case contains(t.Name, "cls", "output", "output_norm"):
createTensor(tensor{source: t}, output.bts)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
// TODO: assign vision tensors to the gpu if possible
createTensor(tensor{source: t}, input.bts)
createTensor(tensor{source: t}, output.bts)
case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
// these tensors should be repeated per layer
for i, layer := range layers {
createTensor(tensor{
source: t,
target: "blk." + strconv.Itoa(i) + "." + t.Name,
}, layer.bts)
}
default:
layerIndex := -1
if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
@ -256,14 +267,8 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
if layerIndex >= 0 {
createTensor(tensor{source: t}, layers[layerIndex].bts)
} else {
// this is a repeating tensor that doesn't explicitly associated with a layer so
// duplicate it for each layer
for i, layer := range layers {
createTensor(tensor{
source: t,
target: "blk." + strconv.Itoa(i) + "." + t.Name,
}, layer.bts)
}
// load all other tensors on the cpu
createTensor(tensor{source: t}, input.bts)
}
}
}
@ -352,7 +357,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
if C.ggml_backend_is_cpu(b) {
// set number of threads for cpu backend
C.ggml_backend_cpu_set_n_threads(b, C.int(params.NumThreads))
C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
}
}
@ -893,10 +898,13 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
const (
ropeTypeNorm C.int = iota
ropeTypeNorm C.int = 0
ropeTypeNeox C.int = 2
ropeTypeMrope C.int = 8
ropeTypeVision C.int = 24
)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
@ -911,8 +919,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
131072, // YaRN n_ctx_train
ropeTypeNorm, // ROPE_TYPE_NORM
C.int(ropeType),
131072, // YaRN n_ctx_train
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor
@ -944,6 +952,27 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
}
}
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
var tt *C.struct_ggml_tensor
switch len(strides) {
case 0:
tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
case 1:
tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
default:
panic("unsupported number of dimensions")
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {

View File

@ -0,0 +1,11 @@
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
void ollama_debug(const struct ggml_tensor *tensor, bool verbose);
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,6 @@
//go:build debug
package cpu
// #cgo CPPFLAGS: -DOLLAMA_DEBUG
import "C"

View File

@ -11,6 +11,8 @@
#include "ggml-threading.h"
#include "ggml.h"
#include "ollama-debug.h"
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
ggml_compute_forward(&params, node);
#ifdef OLLAMA_DEBUG
ollama_debug(node, true);
#endif
if (state->ith == 0 && cplan->abort_callback &&
cplan->abort_callback(cplan->abort_callback_data)) {
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);

115
ml/backend/ggml/ggml/src/ollama-debug.c vendored Normal file
View File

@ -0,0 +1,115 @@
#include <string.h>
#include "ollama-debug.h"
static int mul(int64_t *dims, int ndims) {
int result = 1;
for (int i = 0; i < ndims; i++) {
result *= dims[i];
}
return result;
}
static void repeat(char c, int n) {
for (int i = 0; i < n; i++) {
fprintf(stderr, "%c", c);
}
}
static void print_tensor(const void *tensor, void (*cb)(const void *, int),
int shape,
int64_t *dims, int ndims, int stride,
int nitems, int pad) {
fprintf(stderr, "[");
for (int i = 0; i < dims[0]; i++) {
if (i >= nitems && i < dims[0] - nitems) {
fprintf(stderr, "... (%lld more), ", dims[0] - 2 * nitems);
int skip = dims[0] - 2 * nitems;
if (ndims > 1) {
stride += mul(dims + 1, ndims - 1) * skip;
repeat('\n', ndims - 1);
repeat(' ', shape - ndims + 1 + pad);
}
i += skip - 1;
} else if (ndims > 1) {
print_tensor(tensor, cb, shape, dims + 1, ndims - 1, stride,
nitems, pad);
stride += mul(dims + 1, ndims - 1);
if (i < dims[0] - 1) {
fprintf(stderr, ", ");
repeat('\n', ndims - 1);
repeat(' ', shape - ndims + 1 + pad);
}
} else {
cb(tensor, stride + i);
if (i < dims[0] - 1) {
fprintf(stderr, ", ");
}
}
}
fprintf(stderr, "]");
}
static void print_tensor_f16(const void *tensor, int i) {
float value = ggml_fp16_to_fp32(((const ggml_fp16_t *)tensor)[i]);
fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
}
static void print_tensor_f32(const void *tensor, int i) {
float value = ((const float *)tensor)[i];
fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
}
static void print_tensor_i32(const void *tensor, int i) {
int32_t value = ((const int32_t *)tensor)[i];
fprintf(stderr, "%s%d", value < 0 ? "" : " ", value);
}
static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
fprintf(stderr, "%s%s %s (%s): [%lld %lld %lld %lld]\n", prefix, tensor->name,
ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
tensor->ne[1], tensor->ne[2], tensor->ne[3]);
if (!verbose) {
return;
}
for (int i = 0; i < indent; i++) {
fprintf(stderr, " ");
}
switch (tensor->type) {
case GGML_TYPE_F16:
print_tensor(ggml_get_data(tensor), print_tensor_f16, ggml_n_dims(tensor),
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
break;
case GGML_TYPE_F32:
print_tensor(ggml_get_data(tensor), print_tensor_f32, ggml_n_dims(tensor),
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
break;
case GGML_TYPE_I32:
print_tensor(ggml_get_data(tensor), print_tensor_i32, ggml_n_dims(tensor),
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
break;
default:
fprintf(stderr, "<unsupported type>\n");
return;
}
fprintf(stderr, "\n");
}
void ollama_debug(const struct ggml_tensor *tensor, bool verbose) {
ollama_debug_tensor(tensor, verbose, ">>> ", 4);
for (int i = 0; i < GGML_MAX_SRC && tensor->src[i] != NULL; ++i) {
char src[8];
const int n = snprintf(src, sizeof(src), " src%d ", i);
if (n >= sizeof(src)) {
src[sizeof(src) - 1] = '\0';
}
ollama_debug_tensor(tensor->src[i], verbose, src, 4);
}
}

View File

@ -0,0 +1,7 @@
//go:build !debug
package ggml
func Threads(n int) int {
return n
}

View File

@ -0,0 +1,7 @@
//go:build debug
package ggml
func Threads(_ int) int {
return 1
}

View File

@ -0,0 +1,220 @@
package gemma2
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeBase, ropeScale float32
attnLogitSoftcap float32
finalLogitSoftcap float32
largeModelScaling bool
}
type Model struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"` // just set to token_embd?
*Options
}
const (
gemma27BLayerCount = 46
)
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length")),
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
m.Cache.SetConfig(ml.CacheConfig{})
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.Options.largeModelScaling = true
}
for i, layer := range m.Layers {
cacheType := i % 2
m.Cache.SetLayer(i)
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
return hiddenState.Rows(ctx, outputs), nil
}
func init() {
model.Register("gemma2", New)
}

View File

@ -0,0 +1,173 @@
package gemma3
import (
"bytes"
"encoding/binary"
"hash/fnv"
"image"
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.SentencePieceModel
*VisionModel `gguf:"v,vision"`
*TextModel
*MultiModalProjector `gguf:"mm"`
ImageProcessor
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
InputProjection *nn.Linear `gguf:"mm_input_projection"`
tokensPerImage int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
l := visionOutputs.Dim(0)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
patchesPerImage := imageSize / patchSize
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
return visionOutputs
}
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(1),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOT: int32(106),
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
MultiModalProjector: &MultiModalProjector{
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
return visionOutputs, nil
}
type imageToken struct {
embedding ml.Tensor
index int
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var result []input.Input
fnvHash := fnv.New64a()
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
imageInputs := []input.Input{
{Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>""
}
result = append(result, imageInputs...)
// add image embeddings
inputMultimodal := inp.Multimodal.(ml.Tensor)
for i := range inputMultimodal.Dim(1) {
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
fnvHash.Write([]byte{byte(i)})
imageToken := imageToken{embedding: inputMultimodal, index: i}
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
}
result = append(result,
input.Input{Token: 256000}, // <end_of_image>
input.Input{Token: 108}, // "\n\n"
)
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
}
func init() {
model.Register("gemma3", New)
}

View File

@ -0,0 +1,254 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
finalLogitSoftcap float32
largeModelScaling bool
}
type TextModel struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []TextLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
const (
gemmaGlobalCacheCount = 6
gemma27BLayerCount = 62
)
const (
cacheTypeSWA = iota
cacheTypeCausal
)
func newTextModel(c ml.Config) *TextModel {
numBlocks := int(c.Uint("block_count"))
m := TextModel{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
},
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
return &m
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
scaleFactor := 1.0
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextOptions.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextOptions.ropeGlobalBase
}
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}
type TextMLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
var embedding ml.Tensor
var src, dst, length int
var except []int
for _, image := range multimodal {
imageToken := image.Multimodal.(imageToken)
imageSrc := imageToken.index
imageDst := image.Index
if embedding == nil {
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
src = imageSrc
dst = imageDst
length++
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
length++
} else {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
}
except = append(except, imageDst)
}
if embedding != nil {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
}
return except
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
cacheType := cacheTypeSWA
if (i+1)%gemmaGlobalCacheCount == 0 {
cacheType = cacheTypeCausal
}
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
}

View File

@ -0,0 +1,127 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
return hiddenState
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embedding"`
PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
positions := make([]int32, numPatches)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
},
}
}

View File

@ -0,0 +1,58 @@
package gemma3
import (
"image"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels int
}
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
var pixelVals, rVals, gVals, bVals []float32
bounds := img.Bounds()
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
return pixelVals
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil
}

View File

@ -76,14 +76,15 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -96,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {

View File

@ -20,14 +20,15 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -40,8 +41,9 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
return key, nil

View File

@ -144,8 +144,6 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
return images
}
// remove the "alpha" channel by drawing over a prefilled image
//
// remove the "alpha" channel by drawing over a prefilled image
//
//nolint:unused

View File

@ -1,6 +1,8 @@
package models
import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@ -4,6 +4,7 @@ import (
"cmp"
"iter"
"log/slog"
"slices"
"strings"
"sync"
@ -18,6 +19,15 @@ const (
SpecialEOS
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
@ -27,11 +37,11 @@ type TextProcessor interface {
type Vocabulary struct {
Values []string
Types []uint32
Scores []uint32
Scores []float32
Merges []string
BOS, EOS int32
AddBOS, AddEOS bool
BOS, EOS, EOT int32
AddBOS, AddEOS, AddEOT bool
specialOnce sync.Once
special []string
@ -48,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
case SpecialBOS:
return id == v.BOS
case SpecialEOS:
return id == v.EOS
return id == v.EOS || id == v.EOT
default:
return false
}
@ -76,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == 3 {
if slices.Contains([]int{105, 106}, i) {
v.special = append(v.special, v.Values[i])
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}

246
model/process_text_spm.go Normal file
View File

@ -0,0 +1,246 @@
package model
import (
"iter"
"log/slog"
"strings"
"github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
)
const spmWhitespaceSep = "▁"
func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
}
type SentencePieceModel struct {
maxTokenLen int
pre *regexp2.Regexp
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePieceModel)(nil)
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePieceModel{
maxTokenLen: maxTokenLen,
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
vocab: vocab,
}
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
if !yield(m.String()) {
break
}
}
}
}
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
slog.Debug("fragments", "frags", fragments)
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range spm.split(frag.value) {
split = replaceWhitespaceBySeperator(split)
var sb strings.Builder
sb.Write([]byte(split))
if id := spm.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
pq := queue.NewWith(func(a, b any) int {
priA := a.(*candidate)
priB := b.(*candidate)
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
return -1
}
return 1
})
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
slog.Debug("tokenizer", "merges", merges)
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pq.Enqueue(pair)
}
}
pqv := pq.Values()
for _, v := range pqv {
e := v.(*candidate)
slog.Debug("candidate", "candidate", e)
}
for !pq.Empty() {
v, _ := pq.Dequeue()
pair := v.(*candidate)
left, right := merges[pair.a], merges[pair.b]
slog.Debug("pair", "left", left, "right", right)
if len(left.runes) == 0 || len(right.runes) == 0 {
continue
}
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pq.Enqueue(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pq.Enqueue(pair)
}
}
slog.Debug("merges", "merges", merges)
for _, merge := range merges {
if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
} else {
slog.Debug("missing token", "token", string(merge.runes))
}
}
}
}
}
if addSpecial && len(ids) > 0 {
if spm.vocab.AddBOS {
if ids[0] == spm.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
}
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
ids = append([]int32{spm.vocab.BOS}, ids...)
}
if spm.vocab.AddEOS {
if ids[len(ids)-1] == spm.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
}
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
ids = append(ids, spm.vocab.EOS)
}
}
return ids, nil
}
type candidate struct {
a, b int
score float32
}
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}

View File

@ -0,0 +1,118 @@
package model
import (
"log/slog"
"os"
"path/filepath"
"slices"
"testing"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}
var spm sentencepiece.ModelProto
if err := proto.Unmarshal(bts, &spm); err != nil {
t.Fatal(err)
}
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
var v Vocabulary
for _, piece := range spm.GetPieces() {
v.Values = append(v.Values, piece.GetPiece())
v.Scores = append(v.Scores, piece.GetScore())
switch t := piece.GetType(); t {
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, uint32(t))
default:
tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed
v.Types = append(v.Types, tt)
}
}
return NewSentencePieceModel(preTokenizer, &v)
}
func TestSentencePieceEncode(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
tokenizer := loadSentencePieceVocab(t)
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
"你好",
"Hello 你好 world!",
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
"Numbers and symbols: 123456789 +- */",
"Special tokens: <bos> text <eos>",
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Fatal(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q [%#v]", got, want, ids)
}
}
})
t.Run("special tokens", func(t *testing.T) {
type candidate struct {
token string
ids []int32
}
cases := []candidate{
{"<bos>", []int32{2}},
{"<eos>", []int32{1}},
}
for _, want := range cases {
ids, err := tokenizer.Encode(want.token, true)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ids, want.ids) {
t.Errorf("got %#v, want %#v", ids, want.ids)
}
}
})
}

BIN
model/testdata/gemma2/tokenizer.model vendored Normal file

Binary file not shown.

View File

@ -26,6 +26,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
var system []api.Message
isMllama := checkMllamaModelFamily(m)
isGemma3 := checkGemma3ModelFamily(m)
var imageNumTokens int
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@ -40,7 +41,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n; i >= 0; i-- {
if isMllama && len(msgs[i].Images) > 1 {
if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
return "", nil, errTooManyImages
}
@ -157,3 +158,12 @@ func checkMllamaModelFamily(m *Model) bool {
}
return false
}
func checkGemma3ModelFamily(m *Model) bool {
for _, arch := range m.Config.ModelFamilies {
if arch == "gemma3" {
return true
}
}
return false
}