add gemma vision encoder

Michael Yang 2025-03-06 12:16:54 -08:00
parent 5f74d1fd47
commit 4b037a97dc
10 changed files with 337 additions and 34 deletions

View File

@@ -5,7 +5,16 @@ import "github.com/ollama/ollama/fs/ggml"
type gemma3Model struct {
gemmaModel
TextModel gemma3TextModel `json:"text_config"`
VisionModel gemma3VisionModel `json:"vision_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"` // attention.head_count 16
LayerNormEpsilon float32 `json:"layer_norm_eps"` // attention.layer_norm_epsilon 1e-05
NumHiddenLayers uint32 `json:"num_hidden_layers"` // block_count 32
HiddenSize uint32 `json:"hidden_size"` // embedding_length 1280
IntermediateSize uint32 `json:"intermediate_size"` // feed_forward_length 5120
ImageSize uint32 `json:"image_size"` // image_size 560
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
}
type gemma3TextModel struct {
@@ -24,12 +33,6 @@ type gemma3TextModel struct {
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
}
type gemma3VisionModel struct {
ImageSize uint32 `json:"image_size"`
NumChannels uint32 `json:"num_channels"`
HiddenLayers uint32 `json:"num_hidden_layers"`
}
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"
@@ -46,11 +49,18 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.text.final_logit_softcapping"] = p.TextModel.FinalLogitSoftcap
kv["gemma3.text.rope.local.freq_base"] = p.TextModel.RopeLocalTheta
kv["gemma3.text.rope.global.freq_base"] = p.TextModel.RopeGlobalTheta
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
kv["tokenizer.ggml.bos_token_id"] = uint32(2) kv["tokenizer.ggml.bos_token_id"] = uint32(2)
kv["tokenizer.ggml.eot_token_id"] = uint32(1) kv["tokenizer.ggml.eot_token_id"] = uint32(1)
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.block_count"] = p.VisionModel.HiddenLayers
return kv
}
@@ -59,11 +69,11 @@ func (p *gemma3Model) Replacements() []string {
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"vision_model.vision_model", "v",
"vision_tower.vision_model.embeddings", "v",
"vision_tower.vision_model", "v",
"language_model.", "", "language_model.", "",
"model.layers", "blk", "model.layers", "blk",
"encoder.layers", "blk", "encoder.layers", "blk",
"vision_tower.vision_model.embeddings", "v",
"input_layernorm", "attn_norm", "input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q", "self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm", "self_attn.q_norm", "attn_q_norm",
@ -71,11 +81,14 @@ func (p *gemma3Model) Replacements() []string {
"self_attn.k_norm", "attn_k_norm", "self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v", "self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output", "self_attn.o_proj", "attn_output",
"self_attn.out_proj", "attn_output",
"mlp.gate_proj", "ffn_gate", "mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down", "mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up", "mlp.up_proj", "ffn_up",
"post_attention_layernorm", "post_attention_norm", "post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm", "pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm", "post_ffw_norm", "post_feedforward_layernorm", "post_ffw_norm",
"input_projection_weight", "input_projection.weight",
"multi_modal_projector", "mm",
}
}
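
The replacement list is order-sensitive: "vision_tower.vision_model.embeddings" is listed before the shorter "vision_tower.vision_model" so the longer prefix matches first. A minimal, hypothetical sketch of how such old/new pairs rename checkpoint tensor names (the converter's exact wiring is assumed; strings.NewReplacer tries patterns in argument order at each position):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// A subset of the pairs returned by Replacements, in the same order.
	r := strings.NewReplacer(
		"vision_tower.vision_model.embeddings", "v",
		"vision_tower.vision_model", "v",
		"encoder.layers", "blk",
		"self_attn.out_proj", "attn_output",
		"multi_modal_projector", "mm",
	)

	fmt.Println(r.Replace("vision_tower.vision_model.embeddings.patch_embedding.weight"))
	// v.patch_embedding.weight
	fmt.Println(r.Replace("vision_tower.vision_model.encoder.layers.0.self_attn.out_proj.weight"))
	// v.blk.0.attn_output.weight
}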

View File

@@ -135,7 +135,9 @@ type Tensor interface {
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
AvgPool1D(ctx Context, k, s, p int) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
Tanh(ctx Context) Tensor

View File

@@ -947,6 +947,13 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_pool_1d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(s), C.int(p)),
}
}
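
ggml_pool_1d with GGML_OP_POOL_AVG averages k consecutive values along the first dimension with stride s and padding p. A plain-Go sketch of that semantics (not the ggml kernel itself), shown here because the vision path below relies on it to pool patch embeddings into a fixed number of image tokens:

package main

import "fmt"

// avgPool1D slides a window of size k with stride s (no padding) and
// emits the average of each window, mirroring what the wrapper above
// asks ggml to compute along the token dimension.
func avgPool1D(xs []float32, k, s int) []float32 {
	var out []float32
	for i := 0; i+k <= len(xs); i += s {
		var sum float32
		for _, v := range xs[i : i+k] {
			sum += v
		}
		out = append(out, sum/float32(k))
	}
	return out
}

func main() {
	// Eight values pooled with k = s = 4 collapse to two averages.
	fmt.Println(avgPool1D([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 4, 4)) // [2.5 6.5]
}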
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {

View File

@@ -1,10 +1,15 @@
package gemma3
import (
"fmt"
"bytes"
"encoding/binary"
"hash/fnv"
"image"
"slices"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
@@ -13,19 +18,30 @@ type Model struct {
model.Base
model.SentencePieceModel
//*VisionModel `gguf:"v,vision"`
*VisionModel `gguf:"v,vision"`
*TextModel
//Projector *nn.Linear `gguf:"mm.0"`
*MultiModalProjector `gguf:"mm"`
ImageProcessor
}
func New(c ml.Config) (model.Model, error) {
// Verify unified config
if c.Uint("vision.block_count") == 0 {
return nil, fmt.Errorf("non-unified vision model not supported")
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
InputProjection *nn.Linear `gguf:"mm_input_projection"`
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
return visionOutputs
}
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -40,7 +56,7 @@ func New(c ml.Config) (model.Model, error) {
},
),
ImageProcessor: newImageProcessor(c),
//VisionModel: newVisionModel(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
}
@@ -50,7 +66,78 @@ func New(c ml.Config) (model.Model, error) {
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
if err != nil {
return nil, err
}
positionIDs, err := ctx.FromIntSlice([]int32{0}, 1)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, positionIDs)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
kernelSize := patchesPerImage * patchesPerImage / 256
visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
return visionOutputs, nil
}
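
To make the arithmetic concrete: assuming Gemma 3's 896-pixel square inputs and 14-pixel patches (the converter comments above show placeholder values), patchesPerImage = 896/14 = 64, so kernelSize = 64*64/256 = 16 and AvgPool1D averages the 4096 patch embeddings in groups of 16, leaving 256 vectors per image for the projector to map into the text embedding space.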
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var images []input.Input
fnvHash := fnv.New64a()
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
images = nil
}
} else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
return inputs, nil
}
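
When several images precede a text token, their embeddings are concatenated and their hashes folded together so the resulting input carries a single combined hash. A self-contained sketch of that hash-combining step (same FNV-1a pattern as above; the helper name is hypothetical):

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// combineHashes folds two per-image hashes into one, mirroring the
// fnvHash.Reset / binary.Write / Sum64 sequence in PostTokenize.
func combineHashes(a, b uint64) uint64 {
	h := fnv.New64a()
	binary.Write(h, binary.NativeEndian, a)
	binary.Write(h, binary.NativeEndian, b)
	return h.Sum64()
}

func main() {
	fmt.Printf("%#x\n", combineHashes(0xaaaa, 0xbbbb))
}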
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var embeddings ml.Tensor
if opts.Multimodal != nil {
embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
}
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
@@ -66,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, m.Cache), nil
return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil
}
func init() {

View File

@@ -160,9 +160,12 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
if embeddings == nil {
embeddings = m.TokenEmbedding.Forward(ctx, inputs)
}
hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.TextOptions.largeModelScaling = true

View File

@@ -0,0 +1,171 @@
package gemma3
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize).Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
return hiddenState
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionEncoder struct {
Layers []VisionEncoderLayer
}
func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
var intermediateHiddenStates []ml.Tensor
for i, layer := range e.Layers {
if slices.Contains(intermediateLayersIndices, uint32(i)) {
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
}
hiddenState = layer.Forward(ctx, hiddenState, opts)
}
return hiddenState, intermediateHiddenStates
}
type PrecomputedAspectRatioEmbedding struct {
Embedding *nn.Embedding
Gate ml.Tensor `gguf:"gate"`
}
func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
if e.Gate != nil {
embeddings = embeddings.Mul(ctx, e.Gate)
}
return hiddenState.Add(ctx, embeddings)
}
type PrecomputedPositionEmbedding struct {
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
PositionEmbeddingGate ml.Tensor `gguf:"position_embd.gate"`
}
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, numPositions int, opts *VisionModelOptions) ml.Tensor {
positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
if e.PositionEmbeddingGate != nil {
positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
}
return hiddenState.Add(ctx, positionEmbedding)
}
type VisionModelOptions struct {
hiddenSize, numHeads, numTiles int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embedding"`
PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
Encoder *VisionEncoder `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
positions := m.PositionEmbedding.Forward(ctx, positionIDs)
hiddenState = hiddenState.Add(ctx, positions)
for _, layer := range m.Encoder.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Encoder: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
},
}
}

View File

@@ -8,12 +8,13 @@ import (
)
type ImageProcessor struct {
imageSize, numChannels int
imageSize, patchSize, numChannels int
}
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}

View File

@@ -144,8 +144,6 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
return images
}
// remove the "alpha" channel by drawing over a prefilled image
//
// remove the "alpha" channel by drawing over a prefilled image
//
//nolint:unused

View File

@@ -21,6 +21,8 @@ type SentencePieceModel struct {
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePieceModel)(nil)
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
@@ -61,7 +63,7 @@ func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
}
}
func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
@@ -196,7 +198,26 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
}
}
}
slog.Debug("encoded", "ids", ids)
if addSpecial && len(ids) > 0 {
if spm.vocab.AddBOS {
if ids[0] == spm.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
}
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
ids = append([]int32{spm.vocab.BOS}, ids...)
}
if spm.vocab.AddEOS {
if ids[len(ids)-1] == spm.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
}
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
ids = append(ids, spm.vocab.EOS)
}
}
return ids, nil
}
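
The new addSpecial flag wraps the encoded ids with the vocabulary's BOS/EOS tokens (for gemma3, BOS is id 2 and EOT id 1 per the converter above). A standalone sketch of the same behavior, keeping the warn-but-still-add policy (function and flag names are hypothetical):

package main

import "fmt"

// addSpecialTokens mirrors the logic added to Encode: when requested, BOS is
// prepended and EOS appended even if already present; duplicates only warn.
func addSpecialTokens(ids []int32, bos, eos int32, addBOS, addEOS bool) []int32 {
	if len(ids) == 0 {
		return ids
	}
	if addBOS {
		if ids[0] == bos {
			fmt.Println("warning: prompt already begins with BOS")
		}
		ids = append([]int32{bos}, ids...)
	}
	if addEOS {
		if ids[len(ids)-1] == eos {
			fmt.Println("warning: prompt already ends with EOS")
		}
		ids = append(ids, eos)
	}
	return ids
}

func main() {
	fmt.Println(addSpecialTokens([]int32{10, 11, 12}, 2, 1, true, false)) // [2 10 11 12]
}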