commit aee28501b5
@@ -13,8 +13,13 @@ import (
)

type ModelParameters struct {
    Architectures []string `json:"architectures"`
    VocabSize     uint32   `json:"vocab_size"`

    TextModel TextParameters `json:"text_config"`
}

type TextParameters struct {
    VocabSize uint32 `json:"vocab_size"`
}

type AdapterParameters struct {
@@ -185,6 +190,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
        conv = &gemmaModel{}
    case "Gemma2ForCausalLM":
        conv = &gemma2Model{}
    case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
        conv = &gemma3Model{Architecture: p.Architectures[0]}
    case "Phi3ForCausalLM":
        conv = &phi3Model{}
    case "Qwen2ForCausalLM":
@@ -213,7 +220,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
    }

    vocabSize := int(p.VocabSize)
    if vocabSize == 0 {
        vocabSize = int(p.TextModel.VocabSize)
    }

    switch {
    case vocabSize == 0:
        slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
    case vocabSize > len(t.Vocabulary.Tokens):
        slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
        for i := range vocabSize - len(t.Vocabulary.Tokens) {
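The padding branch above appends placeholder tokens until the tokenizer's list matches the vocabulary size declared in config.json, so the token-embedding and output tensors keep their expected shapes. A minimal sketch of the idea in isolation (the helper name and pad-token format are illustrative, not from this commit):

// padVocabulary is a hypothetical helper: it fills the gap between the
// tokenizer's token list and the declared vocabulary size (fmt assumed imported).
func padVocabulary(tokens []string, vocabSize int) []string {
    for i := len(tokens); i < vocabSize; i++ {
        tokens = append(tokens, fmt.Sprintf("[PAD%d]", i))
    }
    return tokens
}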
@@ -45,7 +45,7 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
    var out []ggml.Tensor
    for _, t := range ts {
-       if strings.HasSuffix(t.Name(), "_norm.weight") {
+       if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
            t.SetRepacker(p.addOne)
        }
convert/convert_gemma3.go (new file, 142 lines)
@@ -0,0 +1,142 @@
package convert

import (
    "cmp"

    "github.com/ollama/ollama/fs/ggml"
)

type gemma3Model struct {
    gemmaModel
    Architecture string
    TextModel    struct {
        HeadDim          uint32 `json:"head_dim"`
        HiddenSize       uint32 `json:"hidden_size"`
        HiddenLayers     uint32 `json:"num_hidden_layers"`
        IntermediateSize uint32 `json:"intermediate_size"`
        SlidingWindow    uint32 `json:"sliding_window"`
    } `json:"text_config"`
    VisionModel struct {
        NumAttentionHeads uint32  `json:"num_attention_heads"` // attention.head_count 16
        LayerNormEpsilon  float32 `json:"layer_norm_eps"`      // attention.layer_norm_epsilon 1e-05
        NumHiddenLayers   uint32  `json:"num_hidden_layers"`   // block_count 32
        HiddenSize        uint32  `json:"hidden_size"`         // embedding_length 1280
        IntermediateSize  uint32  `json:"intermediate_size"`   // feed_forward_length 5120
        ImageSize         uint32  `json:"image_size"`          // image_size 560
        NumChannels       uint32  `json:"num_channels"`        // num_channels 3
        PatchSize         uint32  `json:"patch_size"`          // patch_size 14
    } `json:"vision_config"`
    MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
    NumAttentionHeads        uint32  `json:"num_attention_heads"`
    NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
    RMSNormEPS               float32 `json:"rms_norm_eps"`
    HeadDim                  uint32  `json:"head_dim"`
    FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
    RopeLocalTheta           float32 `json:"rope_local_base_freq"`
    RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
    SlidingWindow            uint32  `json:"sliding_window"`
    MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
}

const (
    gemma4BLayerCount  = 34
    gemma12BLayerCount = 48
    gemma27BLayerCount = 62
)

func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
    kv := p.ModelParameters.KV(t)
    kv["general.architecture"] = "gemma3"

    numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
    kv["gemma3.block_count"] = numBlocks

    var (
        numHeads   uint32
        numKVHeads uint32
    )

    switch numBlocks {
    case gemma4BLayerCount:
        numHeads = 8
        numKVHeads = 4
    case gemma12BLayerCount:
        numHeads = 16
        numKVHeads = 8
    case gemma27BLayerCount:
        numHeads = 32
        numKVHeads = 16
    default:
        numHeads = p.NumAttentionHeads
        numKVHeads = p.NumKeyValueHeads
    }

    kv["gemma3.attention.head_count"] = numHeads
    kv["gemma3.attention.head_count_kv"] = numKVHeads

    switch p.Architecture {
    case "Gemma3ForCausalLM":
        kv["gemma3.context_length"] = p.MaxPositionEmbeddings
        kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
        kv["gemma3.attention.key_length"] = p.HeadDim
        kv["gemma3.attention.value_length"] = p.HeadDim
        kv["gemma3.attention.sliding_window"] = p.SlidingWindow
        kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
        kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
        kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
        kv["gemma3.embedding_length"] = p.HiddenSize
        kv["gemma3.feed_forward_length"] = p.IntermediateSize
    default:
        kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 8192)
        kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
        kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
        kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
        kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
        kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
        kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
        kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
        kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
        kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
        kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
        kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
        kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
        kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
    }

    if p.MultiModalTokensPerImage > 0 {
        kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
    }

    return kv
}

func (p *gemma3Model) Replacements() []string {
    return []string{
        "lm_head", "output",
        "model.embed_tokens", "token_embd",
        "model.norm", "output_norm",
        "vision_tower.vision_model.embeddings", "v",
        "vision_tower.vision_model", "v",
        "vision_model.vision_model.embeddings", "v",
        "vision_model.vision_model", "v",
        "language_model.", "",
        "model.layers", "blk",
        "encoder.layers", "blk",
        "input_layernorm", "attn_norm",
        "self_attn.q_proj", "attn_q",
        "self_attn.q_norm", "attn_q_norm",
        "self_attn.k_proj", "attn_k",
        "self_attn.k_norm", "attn_k_norm",
        "self_attn.v_proj", "attn_v",
        "self_attn.o_proj", "attn_output",
        "self_attn.out_proj", "attn_output",
        "mlp.gate_proj", "ffn_gate",
        "mlp.down_proj", "ffn_down",
        "mlp.up_proj", "ffn_up",
        "post_attention_layernorm", "post_attention_norm",
        "pre_feedforward_layernorm", "ffn_norm",
        "post_feedforward_layernorm", "post_ffw_norm",
        "input_projection_weight", "input_projection.weight",
        "multi_modal_projector", "mm",
    }
}
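cmp.Or, used throughout KV above, returns the first of its arguments that is not the zero value (Go 1.22+); that is how the converter falls back to defaults when config.json omits a field. For example:

var softcap float32          // zero value: field absent from config.json
v := cmp.Or(softcap, 30)     // 30, the default
v = cmp.Or(float32(25), 30)  // 25, the configured value wins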
@@ -6,7 +6,9 @@ import (
    "errors"
    "fmt"
    "io/fs"
    "log/slog"
    "os"
    "reflect"
    "slices"

    "google.golang.org/protobuf/proto"
@@ -15,6 +17,8 @@ import (
)

func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
    slog.Debug("using spm vocabulary")

    ast, err := parseAdditionalSpecialTokens(fsys)
    if err != nil {
        return nil, err
@@ -43,10 +47,19 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
            v.Types = append(v.Types, int32(t))
        default:
            tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
-           if slices.Contains(ast, piece.GetPiece()) {
+
+           // temporary fix to handle gemma3 broken configs
+           if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
                tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
            }
+
+           for _, t := range ast {
+               if t.Content == piece.GetPiece() {
+                   tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
+                   break
+               }
+           }

            v.Types = append(v.Types, tt)
        }
    }
@@ -78,10 +91,16 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
        return cmp.Compare(i.id, j.id)
    })

-   n := len(v.Tokens)
-   for i, t := range ts {
-       if t.id != i+n {
-           return nil, fmt.Errorf("invalid token id: %d", t.id)
+   for _, t := range ts {
+       if t.id < len(v.Tokens) {
+           if v.Tokens[t.id] == t.content {
+               slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
+               continue
+           }
+           return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
+       }
+       if t.id != len(v.Tokens) {
+           return nil, fmt.Errorf("invalid token id: [%d] at pos [%d]", t.id, len(v.Tokens))
        }

        v.Tokens = append(v.Tokens, t.content)
@@ -92,7 +111,15 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
    return &v, nil
}

-func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
+type specialToken struct {
+   Content    string `json:"content"`
+   Lstrip     bool   `json:"lstrip"`
+   Normalized bool   `json:"normalized"`
+   Rstrip     bool   `json:"rstrip"`
+   SingleWord bool   `json:"single_word"`
+}
+
+func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
    f, err := fsys.Open("special_tokens_map.json")
    if errors.Is(err, os.ErrNotExist) {
        return nil, nil
@@ -102,12 +129,43 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
    defer f.Close()

    var m struct {
-       AdditionalSpecialTokens []string `json:"additional_special_tokens"`
+       AdditionalSpecialTokens any `json:"additional_special_tokens"`
    }

    if err := json.NewDecoder(f).Decode(&m); err != nil {
        return nil, err
    }

-   return m.AdditionalSpecialTokens, nil
+   var ast []specialToken
+
+   switch st := m.AdditionalSpecialTokens.(type) {
+   case []string:
+       for _, s := range st {
+           ast = append(ast, specialToken{Content: s})
+       }
+   case []any:
+       for _, s := range st {
+           // marshal and unmarshal the object to get the special token
+           tMap := s.(map[string]any)
+           data, err := json.Marshal(tMap)
+           if err != nil {
+               return nil, err
+           }
+
+           var token specialToken
+           if err := json.Unmarshal(data, &token); err != nil {
+               return nil, err
+           }
+
+           ast = append(ast, token)
+       }
+   default:
+       slog.Warn("special token", "unknown token", reflect.TypeOf(st))
+   }
+
+   slog.Debug("spm tokenizer", "additional tokens", ast)
+
+   return ast, nil
}
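For reference, special_tokens_map.json carries additional_special_tokens either as bare strings or as objects, which is why the field is now decoded as any. Note that encoding/json unmarshals every JSON array into []any, so the []any branch is the one that fires for both shapes; the []string case is defensive. Hypothetical examples of the two layouts:

{"additional_special_tokens": ["<start_of_turn>", "<end_of_turn>"]}

{"additional_special_tokens": [{"content": "<start_of_turn>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false}]}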
@@ -124,6 +124,19 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
    return s
}

func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
    r := keyValue(kv, key, &array{})
    s := make([]float32, r.size)
    for i := range r.size {
        s[i] = float32(r.values[i].(float32))
    }
    return s
}

func (kv KV) OllamaEngineRequired() bool {
    return kv.Architecture() == "gemma3"
}

func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
    if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
        key = kv.Architecture() + "." + key
@@ -476,7 +489,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
            // vocab graph
            4*batch*(embedding+vocab)+embedding*vocab*105/128,
        )
-   case "gemma", "gemma2":
+   case "gemma", "gemma2", "gemma3":
        fullOffload = max(
            4*batch*(embedding+vocab),
            4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
type Causal struct {
    DType    ml.DType
    Capacity int32
-   causal     bool
    windowSize int32

+   opts CausalOptions
+
    // config controls mostly backend-specific optimizations
    config *ml.CacheConfig
@@ -79,7 +80,6 @@ type cellRange struct {
func NewCausalCache(shift shiftFn) *Causal {
    return &Causal{
-       causal:     true,
        windowSize: math.MaxInt32,
        shiftFn:    shift,
        ctxs:       make(map[int]ml.Context),
@@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal {
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
    return &Causal{
-       causal:     true,
        windowSize: windowSize,
        shiftFn:    shift,
        ctxs:       make(map[int]ml.Context),
@@ -145,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
    c.curBatchSize = len(opts.Positions)
    c.curSequences = opts.Sequences
    c.curPositions = opts.Positions
+   c.opts.Except = nil

    var err error
    c.curLoc, err = c.findStartLoc()
@@ -235,9 +235,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
    mask := make([]float32, batchSize*length)

    for i := range c.curBatchSize {
+       enabled := !slices.Contains(c.opts.Except, i)
        for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
            if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
-               (c.causal && c.cells[j].pos > c.curPositions[i]) ||
+               (enabled && c.cells[j].pos > c.curPositions[i]) ||
                c.cells[j].pos < c.curPositions[i]-c.windowSize {
                mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
            }
@@ -404,15 +405,16 @@ func (c *Causal) SetLayer(layer int) {
    c.curLayer = layer
}

-// SetCausal enables or disables causal mask generation for subsequent calls to Get.
-// This state carries over to future forward passes. The default value is true.
-//
-// ctx may be set to nil if this is called from outside of a forward pass, for
-// example, when initializing the cache.
-func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
-   if c.causal != causal {
-       c.causal = causal
+type CausalOptions struct {
+   // Except lists batch indices that are excluded from causal mask generation
+   Except []int
+}
+
+// SetCausal disables causal mask generation for a particular range of indices in
+// the current batch for subsequent calls to Get. The state resets for the next forward pass.
+func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
+   if !slices.Equal(c.opts.Except, opts.Except) {
+       c.opts = opts
        if ctx != nil {
            var err error
            c.curMask, err = c.buildMask(ctx)
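A usage sketch of the new options, mirroring the call site in the gemma3 text model further below: positions listed in Except keep a full, non-causal view of the batch, which is what image tokens need during prefill.

// Sketch: wrapper is a *kvcache.WrapperCache in scope; the indices
// 4..7 stand in for hypothetical image-token positions in the batch.
if causal, ok := wrapper.UnderlyingCache().(*kvcache.Causal); ok {
    causal.SetCausal(ctx, kvcache.CausalOptions{Except: []int{4, 5, 6, 7}})
}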
@@ -441,11 +441,19 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
    panic("not implemented")
}

+func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+   panic("not implemented")
+}
+
+func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
+   panic("not implemented")
+}
+
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
    panic("not implemented")
}

-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
    panic("not implemented")
}

@@ -495,6 +503,10 @@ func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
    panic("not implemented")
}

+func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
+   panic("not implemented")
+}
+
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
    panic("not implemented")
}
llama/patches/0020-ollama-debug-tensor.patch (new file, 33 lines)
@@ -0,0 +1,33 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Sun, 9 Mar 2025 14:44:16 -0700
Subject: [PATCH] ollama debug tensor

---
 ggml/src/ggml-cpu/ggml-cpu.c | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 2f606d82..ec60e8fc 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -11,6 +11,8 @@
 #include "ggml-threading.h"
 #include "ggml.h"

+#include "ollama-debug.h"
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {

     ggml_compute_forward(&params, node);

+#ifdef OLLAMA_DEBUG
+    ollama_debug(node, true);
+#endif
+
     if (state->ith == 0 && cplan->abort_callback &&
         cplan->abort_callback(cplan->abort_callback_data)) {
         atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
@@ -271,7 +271,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
    var llamaModel *llama.Model
    var textProcessor model.TextProcessor
-   if envconfig.NewEngine() {
+   if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
        textProcessor, err = model.NewTextProcessor(modelPath)
        if err != nil {
            // To prepare for opt-out mode, instead of treating this as an error, we fall back to the old runner
@@ -19,6 +19,7 @@ type Config interface {
    Strings(string, ...[]string) []string
    Uints(string, ...[]uint32) []uint32
+   Floats(string, ...[]float32) []float32
}

type Backend interface {
@@ -134,8 +135,10 @@ type Tensor interface {
    RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
    Scale(ctx Context, s float64) Tensor

+   AvgPool2D(ctx Context, k, s int, p float32) Tensor
    Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-   RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
+
+   RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor

    Tanh(ctx Context) Tensor
    GELU(ctx Context) Tensor
@@ -145,6 +148,7 @@ type Tensor interface {
    View(ctx Context, offset int, shape ...int) Tensor
    Permute(ctx Context, shape ...int) Tensor
    Contiguous(ctx Context) Tensor
+   Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor

    Pad(ctx Context, shape ...int) Tensor
    Unpad(ctx Context, shape ...int) Tensor
@@ -240,11 +240,22 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
        switch {
        case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
            createTensor(tensor{source: t}, input.bts)
+           if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
+               createTensor(tensor{source: t, target: "output.weight"}, output.bts)
+           }
        case contains(t.Name, "cls", "output", "output_norm"):
            createTensor(tensor{source: t}, output.bts)
+       case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
+           // TODO: assign vision tensors to the gpu if possible
+           createTensor(tensor{source: t}, input.bts)
+       case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
+           // these tensors should be repeated per layer
+           for i, layer := range layers {
+               createTensor(tensor{
+                   source: t,
+                   target: "blk." + strconv.Itoa(i) + "." + t.Name,
+               }, layer.bts)
+           }
        default:
            layerIndex := -1
            if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
@@ -256,14 +267,8 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
            if layerIndex >= 0 {
                createTensor(tensor{source: t}, layers[layerIndex].bts)
            } else {
-               // this is a repeating tensor that doesn't explicitly associated with a layer so
-               // duplicate it for each layer
-               for i, layer := range layers {
-                   createTensor(tensor{
-                       source: t,
-                       target: "blk." + strconv.Itoa(i) + "." + t.Name,
-                   }, layer.bts)
-               }
+               // load all other tensors on the cpu
+               createTensor(tensor{source: t}, input.bts)
            }
        }
    }
@@ -352,7 +357,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
        if C.ggml_backend_is_cpu(b) {
            // set number of threads for cpu backend
-           C.ggml_backend_cpu_set_n_threads(b, C.int(params.NumThreads))
+           C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
        }
    }
@@ -893,10 +898,13 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}

const (
-   ropeTypeNorm C.int = iota
+   ropeTypeNorm   C.int = 0
+   ropeTypeNeox   C.int = 2
+   ropeTypeMrope  C.int = 8
+   ropeTypeVision C.int = 24
)

-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
    if ropeFactors == nil {
        ropeFactors = &Tensor{b: t.b}
    }
@@ -911,8 +919,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
        t: C.ggml_rope_ext(
            ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
            C.int(ropeDim),
-           131072,       // YaRN n_ctx_train
-           ropeTypeNorm, // ROPE_TYPE_NORM
+           C.int(ropeType),
+           131072, // YaRN n_ctx_train
            C.float(ropeBase),
            C.float(ropeScale),
            0., // YaRN ext_factor
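The new constants mirror ggml's rope mode values (GGML_ROPE_TYPE_NEOX = 2, GGML_ROPE_TYPE_MROPE = 8, GGML_ROPE_TYPE_VISION = 24), and callers now pass the mode explicitly instead of the previously hard-coded ropeTypeNorm. A hedged call-site sketch, with q, positions, headDim, ropeBase and ropeScale assumed in scope:

// NeoX-style rotation (mode 2), as the gemma models below use.
q = q.RoPE(ctx, positions, nil, uint32(headDim), 2, ropeBase, ropeScale)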
@@ -944,6 +952,27 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
    }
}

+func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
+   return &Tensor{
+       b: t.b,
+       t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
+   }
+}
+
+func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
+   var tt *C.struct_ggml_tensor
+   switch len(strides) {
+   case 0:
+       tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
+   case 1:
+       tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
+   default:
+       panic("unsupported number of dimensions")
+   }
+
+   return &Tensor{b: t.b, t: tt}
+}
+
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
    var kqMask *C.struct_ggml_tensor
    if mask != nil {
ml/backend/ggml/ggml/include/ollama-debug.h (new vendored file, 11 lines)
@@ -0,0 +1,11 @@
#include "ggml.h"

#ifdef __cplusplus
extern "C" {
#endif

void ollama_debug(const struct ggml_tensor *tensor, bool verbose);

#ifdef __cplusplus
}
#endif
ml/backend/ggml/ggml/src/ggml-cpu/cpu_debug.go (new file, 6 lines)
@@ -0,0 +1,6 @@
//go:build debug

package cpu

// #cgo CPPFLAGS: -DOLLAMA_DEBUG
import "C"
ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c (vendored, 6 changed lines)
@@ -11,6 +11,8 @@
#include "ggml-threading.h"
#include "ggml.h"

#include "ollama-debug.h"

#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {

    ggml_compute_forward(&params, node);

#ifdef OLLAMA_DEBUG
    ollama_debug(node, true);
#endif

    if (state->ith == 0 && cplan->abort_callback &&
        cplan->abort_callback(cplan->abort_callback_data)) {
        atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
ml/backend/ggml/ggml/src/ollama-debug.c (new vendored file, 115 lines)
@@ -0,0 +1,115 @@
#include <stdio.h>
#include <string.h>

#include "ollama-debug.h"

static int mul(int64_t *dims, int ndims) {
    int result = 1;
    for (int i = 0; i < ndims; i++) {
        result *= dims[i];
    }

    return result;
}

static void repeat(char c, int n) {
    for (int i = 0; i < n; i++) {
        fprintf(stderr, "%c", c);
    }
}

static void print_tensor(const void *tensor, void (*cb)(const void *, int),
                         int shape,
                         int64_t *dims, int ndims, int stride,
                         int nitems, int pad) {
    fprintf(stderr, "[");
    for (int i = 0; i < dims[0]; i++) {
        if (i >= nitems && i < dims[0] - nitems) {
            fprintf(stderr, "... (%lld more), ", dims[0] - 2 * nitems);
            int skip = dims[0] - 2 * nitems;
            if (ndims > 1) {
                stride += mul(dims + 1, ndims - 1) * skip;
                repeat('\n', ndims - 1);
                repeat(' ', shape - ndims + 1 + pad);
            }
            i += skip - 1;
        } else if (ndims > 1) {
            print_tensor(tensor, cb, shape, dims + 1, ndims - 1, stride,
                         nitems, pad);
            stride += mul(dims + 1, ndims - 1);
            if (i < dims[0] - 1) {
                fprintf(stderr, ", ");
                repeat('\n', ndims - 1);
                repeat(' ', shape - ndims + 1 + pad);
            }
        } else {
            cb(tensor, stride + i);
            if (i < dims[0] - 1) {
                fprintf(stderr, ", ");
            }
        }
    }
    fprintf(stderr, "]");
}

static void print_tensor_f16(const void *tensor, int i) {
    float value = ggml_fp16_to_fp32(((const ggml_fp16_t *)tensor)[i]);
    fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
}

static void print_tensor_f32(const void *tensor, int i) {
    float value = ((const float *)tensor)[i];
    fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
}

static void print_tensor_i32(const void *tensor, int i) {
    int32_t value = ((const int32_t *)tensor)[i];
    fprintf(stderr, "%s%d", value < 0 ? "" : " ", value);
}

static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
    fprintf(stderr, "%s%s %s (%s): [%lld %lld %lld %lld]\n", prefix, tensor->name,
            ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
            tensor->ne[1], tensor->ne[2], tensor->ne[3]);

    if (!verbose) {
        return;
    }

    for (int i = 0; i < indent; i++) {
        fprintf(stderr, " ");
    }

    switch (tensor->type) {
    case GGML_TYPE_F16:
        print_tensor(ggml_get_data(tensor), print_tensor_f16, ggml_n_dims(tensor),
                     (int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
        break;
    case GGML_TYPE_F32:
        print_tensor(ggml_get_data(tensor), print_tensor_f32, ggml_n_dims(tensor),
                     (int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
        break;
    case GGML_TYPE_I32:
        print_tensor(ggml_get_data(tensor), print_tensor_i32, ggml_n_dims(tensor),
                     (int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
        break;
    default:
        fprintf(stderr, "<unsupported type>\n");
        return;
    }

    fprintf(stderr, "\n");
}

void ollama_debug(const struct ggml_tensor *tensor, bool verbose) {
    ollama_debug_tensor(tensor, verbose, ">>> ", 4);

    for (int i = 0; i < GGML_MAX_SRC && tensor->src[i] != NULL; ++i) {
        char src[8];
        const int n = snprintf(src, sizeof(src), " src%d ", i);
        if (n >= sizeof(src)) {
            src[sizeof(src) - 1] = '\0';
        }

        ollama_debug_tensor(tensor->src[i], verbose, src, 4);
    }
}
ml/backend/ggml/threads.go (new file, 7 lines)
@@ -0,0 +1,7 @@
//go:build !debug

package ggml

func Threads(n int) int {
    return n
}
ml/backend/ggml/threads_debug.go (new file, 7 lines)
@@ -0,0 +1,7 @@
//go:build debug

package ggml

func Threads(_ int) int {
    return 1
}
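Together with cpu_debug.go above, these two build-tagged files make per-op tensor tracing an opt-in build: the debug tag defines OLLAMA_DEBUG for the vendored C sources and pins the CPU backend to one thread so the ollama_debug output stays ordered. Assuming standard Go build tags, the two configurations would be selected as:

go build -tags debug .   // tracing on, single-threaded CPU backend
go build .               // regular build: tracing compiled out, Threads(n) == n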
model/models/gemma2/model.go (new file, 220 lines)
@@ -0,0 +1,220 @@
package gemma2

import (
    "math"

    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/model"
    "github.com/ollama/ollama/model/input"
)

type Options struct {
    hiddenSize, numHeads, numKVHeads int
    attnKeyLen, attnValLen           int
    eps, ropeBase, ropeScale         float32
    attnLogitSoftcap                 float32
    finalLogitSoftcap                float32
    largeModelScaling                bool
}

type Model struct {
    model.Base
    model.SentencePieceModel

    TokenEmbedding *nn.Embedding `gguf:"token_embd"`
    Layers         []Layer       `gguf:"blk"`
    OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
    Output         *nn.Linear    `gguf:"output,alt:token_embd"` // just set to token_embd?

    *Options
}

const (
    gemma27BLayerCount = 46
)

func New(c ml.Config) (model.Model, error) {
    m := Model{
        SentencePieceModel: model.NewSentencePieceModel(
            c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Scores: c.Floats("tokenizer.ggml.scores"),
                Types:  c.Uints("tokenizer.ggml.token_type"),
                BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
                EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
            },
        ),
        Layers: make([]Layer, c.Uint("block_count")),
        Options: &Options{
            hiddenSize:        int(c.Uint("embedding_length")),
            numHeads:          int(c.Uint("attention.head_count")),
            numKVHeads:        int(c.Uint("attention.head_count_kv")),
            attnKeyLen:        int(c.Uint("attention.key_length")),
            attnValLen:        int(c.Uint("attention.value_length")),
            eps:               c.Float("attention.layer_norm_rms_epsilon"),
            ropeBase:          c.Float("rope.freq_base", 10000.0),
            ropeScale:         c.Float("rope.freq_scale", 1.0),
            attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
            finalLogitSoftcap: c.Float("final_logit_softcapping"),
        },
    }

    slidingWindowLen := int32(c.Uint("attention.sliding_window"))
    m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
    m.Cache.SetConfig(ml.CacheConfig{})

    return &m, nil
}

type SelfAttention struct {
    Query  *nn.Linear `gguf:"attn_q"`
    Key    *nn.Linear `gguf:"attn_k"`
    Value  *nn.Linear `gguf:"attn_v"`
    Output *nn.Linear `gguf:"attn_output"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
    batchSize := hiddenState.Dim(1)
    ropeType := uint32(2)

    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
    q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)

    if opts.largeModelScaling {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
    } else {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
    }

    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
    k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

    cache.Put(ctx, k, v)
    k, v, mask := cache.Get(ctx)

    q = q.Permute(ctx, 0, 2, 1, 3)
    k = k.Permute(ctx, 0, 2, 1, 3)
    v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

    kq := k.Mulmat(ctx, q)

    // logit softcap
    kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
    kq = kq.Tanh(ctx)
    kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))

    kq = kq.Add(ctx, mask)
    kq = kq.Softmax(ctx)

    kqv := v.Mulmat(ctx, kq)
    kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

    return sa.Output.Forward(ctx, kqv)
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}

type MLP struct {
    Up   *nn.Linear `gguf:"ffn_up"`
    Down *nn.Linear `gguf:"ffn_down"`
    Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
    hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
    return mlp.Down.Forward(ctx, hiddenState)
}

type Layer struct {
    AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
    SelfAttention     *SelfAttention
    PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
    MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
    MLP               *MLP
    PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
    residual := hiddenState

    hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
    hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

    // In the final layer (outputs != nil), optimize by pruning to just the token positions
    // we need logits for.
    if outputs != nil {
        hiddenState = hiddenState.Rows(ctx, outputs)
        residual = residual.Rows(ctx, outputs)
    }

    hiddenState = hiddenState.Add(ctx, residual)
    residual = hiddenState

    hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
    hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
    return hiddenState.Add(ctx, residual)
}

func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
    inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
    if err != nil {
        return nil, err
    }

    positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
    if err != nil {
        return nil, err
    }

    outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
    if err != nil {
        return nil, err
    }

    hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
    hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))

    if len(m.Layers) == gemma27BLayerCount {
        m.Options.largeModelScaling = true
    }

    for i, layer := range m.Layers {
        cacheType := i % 2
        m.Cache.SetLayer(i)
        wc := m.Cache.(*kvcache.WrapperCache)
        wc.SetLayerType(cacheType)

        var lastLayerOutputs ml.Tensor
        if i == len(m.Layers)-1 {
            lastLayerOutputs = outputs
        }

        hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
    }

    hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
    hiddenState = m.Output.Forward(ctx, hiddenState)

    // final logit softcap
    hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
    hiddenState = hiddenState.Tanh(ctx)
    hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
    return hiddenState.Rows(ctx, outputs), nil
}

func init() {
    model.Register("gemma2", New)
}
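The scale/tanh/scale triple that appears twice above is logit softcapping, softcap(x) = c * tanh(x / c), which bounds logits to (-c, c) without hard clipping. A scalar sketch of the same computation (math assumed imported):

// c is attn_logit_softcapping for attention scores and
// final_logit_softcapping (30.0 by the converter default above) for output logits.
func softcap(x, c float64) float64 {
    return c * math.Tanh(x/c)
}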
model/models/gemma3/model.go (new file, 173 lines)
@@ -0,0 +1,173 @@
package gemma3

import (
    "bytes"
    "encoding/binary"
    "hash/fnv"
    "image"
    "math"

    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/model"
    "github.com/ollama/ollama/model/input"
)

type Model struct {
    model.Base
    model.SentencePieceModel

    *VisionModel `gguf:"v,vision"`
    *TextModel

    *MultiModalProjector `gguf:"mm"`

    ImageProcessor
}

var _ model.MultimodalProcessor = (*Model)(nil)

type MultiModalProjector struct {
    SoftEmbNorm     *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
    InputProjection *nn.Linear  `gguf:"mm_input_projection"`

    tokensPerImage int
}

func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
    l := visionOutputs.Dim(0)

    visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
    patchesPerImage := imageSize / patchSize
    visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)

    kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
    visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
    visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
    visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
    visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)

    // TODO: inputProjection must be transposed since it's incompatible with visionOutputs
    visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
    return visionOutputs
}

func New(c ml.Config) (model.Model, error) {
    m := Model{
        SentencePieceModel: model.NewSentencePieceModel(
            c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Scores: c.Floats("tokenizer.ggml.scores"),
                Types:  c.Uints("tokenizer.ggml.token_type"),
                BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
                AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
                EOS:    int32(1),
                AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
                EOT:    int32(106),
                AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
            },
        ),
        ImageProcessor: newImageProcessor(c),
        VisionModel:    newVisionModel(c),
        TextModel:      newTextModel(c),
        MultiModalProjector: &MultiModalProjector{
            tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
        },
    }

    slidingWindowLen := int32(c.Uint("attention.sliding_window"))
    m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))

    return &m, nil
}

func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
    image, _, err := image.Decode(bytes.NewReader(multimodalData))
    if err != nil {
        return nil, err
    }

    f32s, err := m.ImageProcessor.ProcessImage(image)
    if err != nil {
        return nil, err
    }

    pixelValues, err := ctx.Input().FromFloatSlice(f32s,
        m.ImageProcessor.imageSize,
        m.ImageProcessor.imageSize,
        m.ImageProcessor.numChannels,
    )
    if err != nil {
        return nil, err
    }

    visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
    visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
    return visionOutputs, nil
}

type imageToken struct {
    embedding ml.Tensor
    index     int
}

func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
    var result []input.Input
    fnvHash := fnv.New64a()

    for _, inp := range inputs {
        if inp.Multimodal == nil {
            result = append(result, inp)
        } else {
            imageInputs := []input.Input{
                {Token: 108},    // "\n\n"
                {Token: 255999}, // "<start_of_image>"
            }
            result = append(result, imageInputs...)

            // add image embeddings
            inputMultimodal := inp.Multimodal.(ml.Tensor)

            for i := range inputMultimodal.Dim(1) {
                fnvHash.Reset()
                binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
                fnvHash.Write([]byte{byte(i)})

                imageToken := imageToken{embedding: inputMultimodal, index: i}
                result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
            }

            result = append(result,
                input.Input{Token: 256000}, // "<end_of_image>"
                input.Input{Token: 108},    // "\n\n"
            )
        }
    }

    return result, nil
}

func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
    inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
    if err != nil {
        return nil, err
    }

    positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
    if err != nil {
        return nil, err
    }

    outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
    if err != nil {
        return nil, err
    }

    return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
}

func init() {
    model.Register("gemma3", New)
}
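The AvgPool2D step in MultiModalProjector.Forward is what enforces mm_tokens_per_image. Assuming Gemma 3's published vision defaults (image_size 896, patch_size 14, 256 tokens per image), the arithmetic works out as:

patchesPerImage := 896 / 14                // 64x64 patch grid
kernelSize := 64 / 16                      // int(math.Sqrt(256)) = 16, so kernel 4
tokens := (64 / kernelSize) * (64 / kernelSize) // 16*16 = 256 image tokens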
model/models/gemma3/model_text.go (new file, 254 lines)
@@ -0,0 +1,254 @@
package gemma3

import (
    "math"

    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/model"
    "github.com/ollama/ollama/model/input"
)

type TextOptions struct {
    hiddenSize, numHeads, numKVHeads int
    attnKeyLen, attnValLen           int
    eps, ropeScale                   float32
    ropeLocalBase, ropeGlobalBase    float32
    finalLogitSoftcap                float32
    largeModelScaling                bool
}

type TextModel struct {
    model.Base
    model.SentencePieceModel

    TokenEmbedding *nn.Embedding `gguf:"token_embd"`
    Layers         []TextLayer   `gguf:"blk"`
    OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
    Output         *nn.Linear    `gguf:"output,alt:token_embd"`

    *TextOptions
}

const (
    gemmaGlobalCacheCount = 6
    gemma27BLayerCount    = 62
)

const (
    cacheTypeSWA = iota
    cacheTypeCausal
)

func newTextModel(c ml.Config) *TextModel {
    numBlocks := int(c.Uint("block_count"))

    m := TextModel{
        SentencePieceModel: model.NewSentencePieceModel(
            c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Scores: c.Floats("tokenizer.ggml.scores"),
                Types:  c.Uints("tokenizer.ggml.token_type"),
                BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
                EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
            },
        ),
        Layers: make([]TextLayer, numBlocks),
        TextOptions: &TextOptions{
            hiddenSize:        int(c.Uint("embedding_length")),
            numHeads:          int(c.Uint("attention.head_count")),
            numKVHeads:        int(c.Uint("attention.head_count_kv")),
            attnKeyLen:        int(c.Uint("attention.key_length", 256)),
            attnValLen:        int(c.Uint("attention.value_length", 256)),
            eps:               c.Float("attention.layer_norm_rms_epsilon", 1e-06),
            ropeLocalBase:     c.Float("rope.local.freq_base", 10000.0),
            ropeGlobalBase:    c.Float("rope.global.freq_base", 1000000.0),
            ropeScale:         c.Float("rope.freq_scale", 1.0),
            finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
        },
    }

    if numBlocks == gemma27BLayerCount {
        m.largeModelScaling = true
    }

    return &m
}

type TextSelfAttention struct {
    Query     *nn.Linear  `gguf:"attn_q"`
    QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
    Key       *nn.Linear  `gguf:"attn_k"`
    KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
    Value     *nn.Linear  `gguf:"attn_v"`
    Output    *nn.Linear  `gguf:"attn_output"`
}

func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
    batchSize := hiddenState.Dim(1)
    ropeType := uint32(2)

    ropeBase := opts.ropeLocalBase
    if (layer+1)%gemmaGlobalCacheCount == 0 {
        ropeBase = opts.ropeGlobalBase
    }

    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
    q = sa.QueryNorm.Forward(ctx, q, opts.eps)
    q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

    if opts.largeModelScaling {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
    } else {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
    }

    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
    k = sa.KeyNorm.Forward(ctx, k, opts.eps)
    k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

    scaleFactor := 1.0
    kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
    kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

    return sa.Output.Forward(ctx, kqv)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    ropeBase := m.TextOptions.ropeLocalBase
    if (layer+1)%gemmaGlobalCacheCount == 0 {
        ropeBase = m.TextOptions.ropeGlobalBase
    }

    return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}

type TextMLP struct {
    Up   *nn.Linear `gguf:"ffn_up"`
    Down *nn.Linear `gguf:"ffn_down"`
    Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
    hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
    return mlp.Down.Forward(ctx, hiddenState)
}

type TextLayer struct {
    AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
    SelfAttention     *TextSelfAttention
    PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
    MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
    MLP               *TextMLP
    PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
    residual := hiddenState

    hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
    hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

    // In the final layer (outputs != nil), optimize by pruning to just the token positions
    // we need logits for.
    if outputs != nil {
        hiddenState = hiddenState.Rows(ctx, outputs)
        residual = residual.Rows(ctx, outputs)
    }

    hiddenState = hiddenState.Add(ctx, residual)
    residual = hiddenState

    hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
    hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
    return hiddenState.Add(ctx, residual)
}

func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
    var embedding ml.Tensor
    var src, dst, length int
    var except []int

    for _, image := range multimodal {
        imageToken := image.Multimodal.(imageToken)
        imageSrc := imageToken.index
        imageDst := image.Index

        if embedding == nil {
            embedding = imageToken.embedding
            src = imageSrc
            dst = imageDst
            length = 1
        } else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
            src = imageSrc
            dst = imageDst
            length++
        } else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
            length++
        } else {
            visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
            ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))

            embedding = imageToken.embedding
            src = imageSrc
            dst = imageDst
            length = 1
        }

        except = append(except, imageDst)
    }

    if embedding != nil {
        visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
        ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
    }

    return except
}

func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
    hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
    hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))

    except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)

    for i, layer := range m.Layers {
        // gemma alternates between the sliding window (local) and causal (global)
        // kv cache every 6 layers
        cacheType := cacheTypeSWA
        if (i+1)%gemmaGlobalCacheCount == 0 {
            cacheType = cacheTypeCausal
        }
        cache.SetLayer(i)
        wc := cache.(*kvcache.WrapperCache)
        wc.SetLayerType(cacheType)

        if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
            causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
        }

        var lastLayerOutputs ml.Tensor
        if i == len(m.Layers)-1 {
            lastLayerOutputs = outputs
        }

        hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
    }

    hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
    hiddenState = m.Output.Forward(ctx, hiddenState)

    // final logit softcap
    hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
    hiddenState = hiddenState.Tanh(ctx)
    return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
}
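The (i+1)%gemmaGlobalCacheCount == 0 test above gives a five-to-one schedule: layers 5, 11, 17, ... (0-indexed) are global layers using the causal cache and rope.global.freq_base (1e6), while the other five in each group use the sliding-window cache and rope.local.freq_base (1e4). A sketch of the selection in isolation:

// isGlobalLayer reports whether layer i uses the global causal cache
// (and the global rope base) under the schedule above.
func isGlobalLayer(i int) bool {
    return (i+1)%6 == 0 // gemmaGlobalCacheCount
}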
model/models/gemma3/model_vision.go (new file, 127 lines)
@@ -0,0 +1,127 @@
package gemma3

import (
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

var batchSize int = 1

type VisionSelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	headDim := opts.hiddenSize / opts.numHeads

	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)

	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)

	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)

	hiddenState = sa.Output.Forward(ctx, attention)
	return hiddenState
}

type VisionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}

func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
	hiddenState = mlp.FC2.Forward(ctx, hiddenState)
	return hiddenState
}

type VisionEncoderLayer struct {
	LayerNorm1    *nn.LayerNorm `gguf:"layer_norm1"`
	SelfAttention *VisionSelfAttention

	LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
	MLP        *VisionMLP    `gguf:"mlp"`
}

func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	residual := hiddenState

	// self attention
	hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	// feed forward
	hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

type VisionModelOptions struct {
	hiddenSize, numHeads int
	imageSize, patchSize int
	eps                  float32
}

type VisionModel struct {
	PatchEmbedding    *nn.Conv2D    `gguf:"patch_embedding"`
	PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
	PostLayerNorm     *nn.LayerNorm `gguf:"post_layernorm"`

	Layers []VisionEncoderLayer `gguf:"blk"`

	*VisionModelOptions
}

func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
	numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)

	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	positions := make([]int32, numPatches)
	for i := range positions {
		positions[i] = int32(i)
	}

	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
	if err != nil {
		panic(err)
	}

	hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))

	for _, layer := range m.Layers {
		hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
	}

	hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
	return hiddenState
}

func newVisionModel(c ml.Config) *VisionModel {
	return &VisionModel{
		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
		VisionModelOptions: &VisionModelOptions{
			hiddenSize: int(c.Uint("vision.embedding_length")),
			numHeads:   int(c.Uint("vision.attention.head_count")),

			imageSize: int(c.Uint("vision.image_size")),
			patchSize: int(c.Uint("vision.patch_size")),

			eps: c.Float("vision.attention.layer_norm_epsilon"),
		},
	}
}
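Forward above derives one embedding per patchSize×patchSize tile and one position ID per patch. A quick sketch of that arithmetic (the 896/14 figures are placeholder values for illustration, not necessarily what a given GGUF carries in vision.image_size / vision.patch_size):

package main

import "fmt"

func main() {
	// placeholder config values; real ones come from the GGUF KV store
	imageSize, patchSize := 896, 14
	perSide := imageSize / patchSize
	numPatches := perSide * perSide
	fmt.Println(perSide, numPatches) // 64 patches per side -> 4096 position IDs
}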
58
model/models/gemma3/process_image.go
Normal file
@ -0,0 +1,58 @@
package gemma3

import (
	"image"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/imageproc"
)

type ImageProcessor struct {
	imageSize, patchSize, numChannels int
}

func newImageProcessor(c ml.Config) ImageProcessor {
	return ImageProcessor{
		imageSize:   int(c.Uint("vision.image_size")),
		patchSize:   int(c.Uint("vision.patch_size")),
		numChannels: int(c.Uint("vision.num_channels")),
	}
}

func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
	var pixelVals, rVals, gVals, bVals []float32

	bounds := img.Bounds()
	for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
		for x := bounds.Min.X; x < bounds.Max.X; x++ {
			c := img.At(x, y)
			r, g, b, _ := c.RGBA()
			rVal := float32(r>>8) / 255.0
			gVal := float32(g>>8) / 255.0
			bVal := float32(b>>8) / 255.0

			rVal = (rVal - mean[0]) / std[0]
			gVal = (gVal - mean[1]) / std[1]
			bVal = (bVal - mean[2]) / std[2]

			rVals = append(rVals, rVal)
			gVals = append(gVals, gVal)
			bVals = append(bVals, bVal)
		}
	}

	pixelVals = append(pixelVals, rVals...)
	pixelVals = append(pixelVals, gVals...)
	pixelVals = append(pixelVals, bVals...)

	return pixelVals
}

func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
	outputSize := image.Point{p.imageSize, p.imageSize}
	newImage := imageproc.Composite(img)
	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)

	data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
	return data, nil
}
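Note the layout pack produces: all normalized red values first, then all green, then all blue (channel-planar), rather than interleaved RGB per pixel. A toy illustration of just that append order (not part of the diff):

package main

import "fmt"

func main() {
	// two-pixel image, already normalized; r/g/b gathered per channel as in pack
	rVals := []float32{0.1, 0.2}
	gVals := []float32{0.3, 0.4}
	bVals := []float32{0.5, 0.6}

	var pixelVals []float32
	pixelVals = append(pixelVals, rVals...)
	pixelVals = append(pixelVals, gVals...)
	pixelVals = append(pixelVals, bVals...)

	fmt.Println(pixelVals) // [0.1 0.2 0.3 0.4 0.5 0.6] - whole planes, not interleaved
}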
@ -76,14 +76,15 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads
	ropeType := uint32(0)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -96,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}

type MLP struct {
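Both modified call sites thread a new uint32 rope-type selector through RoPE, always passing 0. A hedged reading of that parameter (the constant name and the meaning of nonzero values are assumptions; only the zero value appears in this change):

package llama

// Assumed naming for the new RoPE argument (not from this diff): 0 selects the
// standard rotation style; other encodings (e.g. NeoX-style interleaving)
// would presumably be nonzero. Both Forward and Shift pass this zero value.
const ropeTypeStandard uint32 = 0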
@ -20,14 +20,15 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads
	ropeType := uint32(0)

	query := sa.Query.Forward(ctx, hiddenState)
	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	key := sa.Key.Forward(ctx, hiddenState)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	value := sa.Value.Forward(ctx, hiddenState)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -40,8 +41,9 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	// This will only get called for layers in the cache, which are just the self attention layers
	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
	}

	return key, nil
@ -144,8 +144,6 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
	return images
}

// remove the "alpha" channel by drawing over a prefilled image
//
// remove the "alpha" channel by drawing over a prefilled image
//
//nolint:unused
@ -1,6 +1,8 @@
package models

import (
	_ "github.com/ollama/ollama/model/models/gemma2"
	_ "github.com/ollama/ollama/model/models/gemma3"
	_ "github.com/ollama/ollama/model/models/llama"
	_ "github.com/ollama/ollama/model/models/mllama"
)
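The blank import is enough because each model package registers its constructor at init time; a sketch of the presumed pattern (the model.Register call and the constructor name New are assumptions, not shown in this hunk):

package gemma3

import "github.com/ollama/ollama/model"

func init() {
	// assumed registry call: maps the architecture name to a constructor
	model.Register("gemma3", New)
}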
@ -4,6 +4,7 @@ import (
	"cmp"
	"iter"
	"log/slog"
	"slices"
	"strings"
	"sync"

@ -18,6 +19,15 @@ const (
	SpecialEOS
)

const (
	TOKEN_TYPE_NORMAL = iota + 1
	TOKEN_TYPE_UNKNOWN
	TOKEN_TYPE_CONTROL
	TOKEN_TYPE_USER_DEFINED
	TOKEN_TYPE_UNUSED
	TOKEN_TYPE_BYTE
)

type TextProcessor interface {
	Encode(s string, addSpecial bool) ([]int32, error)
	Decode([]int32) (string, error)
@ -27,11 +37,11 @@ type TextProcessor interface {
type Vocabulary struct {
	Values []string
	Types  []uint32
	Scores []uint32
	Scores []float32
	Merges []string

	BOS, EOS int32
	AddBOS, AddEOS bool
	BOS, EOS, EOT int32
	AddBOS, AddEOS, AddEOT bool

	specialOnce sync.Once
	special     []string
@ -48,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
	case SpecialBOS:
		return id == v.BOS
	case SpecialEOS:
		return id == v.EOS
		return id == v.EOS || id == v.EOT
	default:
		return false
	}
@ -76,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
	v.specialOnce.Do(func() {
		for i := range v.Values {
			if v.Types[i] == 3 {
			if slices.Contains([]int{105, 106}, i) {
				v.special = append(v.special, v.Values[i])
			} else if v.Types[i] == TOKEN_TYPE_CONTROL {
				v.special = append(v.special, v.Values[i])
			}
		}
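TOKEN_TYPE_* starts at iota + 1 so the values line up with SentencePiece's piece-type enum (NORMAL=1 through BYTE=6); the old literal check v.Types[i] == 3 in SpecialVocabulary is exactly TOKEN_TYPE_CONTROL under this numbering. A quick check (the enum alignment with the sentencepiece proto is inferred from the test file below, which stores proto types directly into Vocabulary.Types):

package main

import "fmt"

const (
	TOKEN_TYPE_NORMAL = iota + 1 // 1 - matches sentencepiece NORMAL
	TOKEN_TYPE_UNKNOWN           // 2
	TOKEN_TYPE_CONTROL           // 3 - the value the old literal check used
	TOKEN_TYPE_USER_DEFINED      // 4
	TOKEN_TYPE_UNUSED            // 5
	TOKEN_TYPE_BYTE              // 6
)

func main() {
	fmt.Println(TOKEN_TYPE_CONTROL) // 3
}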
246
model/process_text_spm.go
Normal file
@ -0,0 +1,246 @@
package model

import (
	"iter"
	"log/slog"
	"strings"

	"github.com/dlclark/regexp2"
	queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
)

const spmWhitespaceSep = "▁"

func replaceWhitespaceBySeperator(s string) string {
	return strings.ReplaceAll(s, " ", spmWhitespaceSep)
}

type SentencePieceModel struct {
	maxTokenLen int
	pre         *regexp2.Regexp
	vocab       *Vocabulary
}

var _ TextProcessor = (*SentencePieceModel)(nil)

func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
	slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])

	counter := map[int]int{}
	var maxTokenLen int
	for cnt := range vocab.Types {
		switch vocab.Types[cnt] {
		case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
			maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
			fallthrough
		default:
			counter[int(vocab.Types[cnt])] += 1
		}
	}

	slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
		"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
		"max token len", maxTokenLen)

	return SentencePieceModel{
		maxTokenLen: maxTokenLen,
		pre:         regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
		vocab:       vocab,
	}
}

func (spm SentencePieceModel) Is(id int32, special Special) bool {
	return spm.vocab.Is(id, special)
}

func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
	return func(yield func(string) bool) {
		for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
			if !yield(m.String()) {
				break
			}
		}
	}
}

func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
	fragments := []fragment{{value: s}}
	for _, special := range spm.vocab.SpecialVocabulary() {
		// TODO: process special tokens concurrently
		id := spm.vocab.Encode(special)
		for i := 0; i < len(fragments); i++ {
			frag := fragments[i]
			if len(frag.ids) > 0 {
				continue
			}

			var middle []fragment
			switch i := strings.Index(frag.value, special); {
			case i < 0:
				middle = append(middle, frag)
			case i > 0:
				middle = append(middle, fragment{value: frag.value[:i]})
				fallthrough
			default:
				middle = append(middle, fragment{value: special, ids: []int32{id}})
				if rest := frag.value[i+len(special):]; rest != "" {
					middle = append(middle, fragment{value: rest})
				}
			}

			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
		}
	}
	slog.Debug("fragments", "frags", fragments)

	var ids []int32
	for _, frag := range fragments {
		if len(frag.ids) > 0 {
			ids = append(ids, frag.ids...)
			continue
		}

		for split := range spm.split(frag.value) {
			split = replaceWhitespaceBySeperator(split)

			var sb strings.Builder
			sb.Write([]byte(split))
			if id := spm.vocab.Encode(sb.String()); id >= 0 {
				ids = append(ids, id)
				continue
			}

			runes := []rune(sb.String())
			pq := queue.NewWith(func(a, b any) int {
				priA := a.(*candidate)
				priB := b.(*candidate)
				if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
					return -1
				}
				return 1
			})

			merges := make([]merge, len(runes))
			for r := range runes {
				merges[r] = merge{
					p:     r - 1,
					n:     r + 1,
					runes: []rune{runes[r]},
				}
			}

			slog.Debug("tokenizer", "merges", merges)

			pairwise := func(a, b int) *candidate {
				if a < 0 || b >= len(runes) {
					return nil
				}

				left, right := string(merges[a].runes), string(merges[b].runes)
				if id := spm.vocab.Encode(left + right); id >= 0 {
					return &candidate{
						a:     a,
						b:     b,
						score: spm.vocab.Scores[id],
					}
				}
				return nil
			}

			for i := range len(runes) - 1 {
				if pair := pairwise(i, i+1); pair != nil {
					pq.Enqueue(pair)
				}
			}

			pqv := pq.Values()
			for _, v := range pqv {
				e := v.(*candidate)
				slog.Debug("candidate", "candidate", e)
			}

			for !pq.Empty() {
				v, _ := pq.Dequeue()
				pair := v.(*candidate)
				left, right := merges[pair.a], merges[pair.b]

				slog.Debug("pair", "left", left, "right", right)
				if len(left.runes) == 0 || len(right.runes) == 0 {
					continue
				}

				if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
					continue
				}

				merges[pair.a].runes = append(left.runes, right.runes...)
				merges[pair.b].runes = nil
				merges[pair.a].n = right.n
				if right.n < len(merges) {
					merges[right.n].p = pair.a
				}

				if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
					pq.Enqueue(pair)
				}

				if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
					pq.Enqueue(pair)
				}
			}

			slog.Debug("merges", "merges", merges)

			for _, merge := range merges {
				if len(merge.runes) > 0 {
					if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
						ids = append(ids, id)
					} else {
						slog.Debug("missing token", "token", string(merge.runes))
					}
				}
			}
		}
	}

	if addSpecial && len(ids) > 0 {
		if spm.vocab.AddBOS {
			if ids[0] == spm.vocab.BOS {
				slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
			}

			slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
			ids = append([]int32{spm.vocab.BOS}, ids...)
		}

		if spm.vocab.AddEOS {
			if ids[len(ids)-1] == spm.vocab.EOS {
				slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
			}

			slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
			ids = append(ids, spm.vocab.EOS)
		}
	}

	return ids, nil
}

type candidate struct {
	a, b  int
	score float32
}

func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
	var sb strings.Builder
	for _, id := range ids {
		data := spm.vocab.Decode(id)
		data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
		if _, err := sb.WriteString(data); err != nil {
			return "", err
		}
	}

	slog.Debug("decoded", "ids", ids, "text", sb.String())
	return sb.String(), nil
}
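A hypothetical call site for the new tokenizer, assuming a populated Vocabulary (the test file below shows one way to build it from a sentencepiece model file):

package example

import "github.com/ollama/ollama/model"

// roundtrip pushes a string through the new tokenizer and back; pre is the
// pre-tokenizer regex, vocab a populated *model.Vocabulary built elsewhere.
func roundtrip(pre string, vocab *model.Vocabulary) (string, error) {
	spm := model.NewSentencePieceModel(pre, vocab)
	ids, err := spm.Encode("hello world", true) // true: add BOS/EOS when configured
	if err != nil {
		return "", err
	}
	return spm.Decode(ids) // expected to return "hello world"
}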
118
model/process_text_spm_test.go
Normal file
@ -0,0 +1,118 @@
package model

import (
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"testing"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)

func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
	t.Helper()

	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
	if err != nil {
		t.Fatal(err)
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		t.Fatal(err)
	}

	preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`

	var v Vocabulary

	for _, piece := range spm.GetPieces() {
		v.Values = append(v.Values, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())
		switch t := piece.GetType(); t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			v.Types = append(v.Types, uint32(t))
		default:
			tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
			// todo parse the special tokens file
			//   - this will roundtrip correctly but the <start_of_turn> and
			//     <end_of_turn> tokens aren't processed
			v.Types = append(v.Types, tt)
		}
	}

	return NewSentencePieceModel(preTokenizer, &v)
}

func TestSentencePieceEncode(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	slog.SetDefault(logger)

	tokenizer := loadSentencePieceVocab(t)

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		cases := []string{
			"hello",
			"hello ",
			"hello  ",
			" hello",
			" hello ",
			" hello  ",
			"hello world",
			"请考试我的软件!12345",
			"你好",
			"Hello 你好 world!",
			"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
			"Multilingual: 你好 こんにちは Привет Hola مرحبا",
			"Numbers and symbols: 123456789 +- */",
			"Special tokens: <bos> text <eos>",
			"Code snippets: func main() { fmt.Println(\"Hello World\") }",
			"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
				"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
				"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want, true)
			if err != nil {
				t.Fatal(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q [%#v]", got, want, ids)
			}
		}
	})

	t.Run("special tokens", func(t *testing.T) {
		type candidate struct {
			token string
			ids   []int32
		}

		cases := []candidate{
			{"<bos>", []int32{2}},
			{"<eos>", []int32{1}},
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want.token, true)
			if err != nil {
				t.Fatal(err)
			}
			if !slices.Equal(ids, want.ids) {
				t.Errorf("got %#v, want %#v", ids, want.ids)
			}
		}
	})
}
BIN
model/testdata/gemma2/tokenizer.model
vendored
Normal file
Binary file not shown.
@ -26,6 +26,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
	var system []api.Message

	isMllama := checkMllamaModelFamily(m)
	isGemma3 := checkGemma3ModelFamily(m)

	var imageNumTokens int
	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@ -40,7 +41,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
	n := len(msgs) - 1
	// in reverse, find all messages that fit into context window
	for i := n; i >= 0; i-- {
		if isMllama && len(msgs[i].Images) > 1 {
		if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
			return "", nil, errTooManyImages
		}

@ -157,3 +158,12 @@ func checkMllamaModelFamily(m *Model) bool {
	}
	return false
}

func checkGemma3ModelFamily(m *Model) bool {
	for _, arch := range m.Config.ModelFamilies {
		if arch == "gemma3" {
			return true
		}
	}
	return false
}