set non-causal attention

This commit is contained in:
Michael Yang 2025-03-07 13:52:45 -08:00
parent 631fecc6d9
commit 0df1800436
6 changed files with 57 additions and 25 deletions

View File

@ -58,9 +58,6 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
kv["tokenizer.ggml.bos_token_id"] = uint32(2)
kv["tokenizer.ggml.eot_token_id"] = uint32(1)
return kv return kv
} }

View File

@ -148,6 +148,7 @@ type Tensor interface {
View(ctx Context, offset int, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor Contiguous(ctx Context) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
Pad(ctx Context, shape ...int) Tensor Pad(ctx Context, shape ...int) Tensor
Unpad(ctx Context, shape ...int) Tensor Unpad(ctx Context, shape ...int) Tensor

View File

@ -954,6 +954,20 @@ func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
} }
} }
// Set writes t2 into t in place, starting at byte offset, and returns the
// resulting tensor. The number of variadic strides selects the ggml variant:
// none for a 1D set, one for a 2D set. More than one stride is unsupported.
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
var tt *C.struct_ggml_tensor
switch len(strides) {
case 0:
// 1D: only the destination byte offset is needed.
tt = C.ggml_set_1d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
case 1:
// 2D: strides[0] is the row stride (nb1) in bytes.
// NOTE(review): ggml's C header declares ggml_set_2d_inplace with nb1
// before offset — confirm this argument order matches the cgo binding.
tt = C.ggml_set_2d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
default:
// ggml exposes no in-place set beyond 2D via this wrapper.
panic("unsupported number of dimensions")
}
// Wrap the raw ggml tensor, keeping the original tensor's backend.
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor { func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor var kqMask *C.struct_ggml_tensor
if mask != nil { if mask != nil {

View File

@ -51,8 +51,10 @@ func New(c ml.Config) (model.Model, error) {
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), EOS: int32(1),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOT: int32(106),
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
}, },
), ),
ImageProcessor: newImageProcessor(c), ImageProcessor: newImageProcessor(c),
@ -109,35 +111,46 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
for i := range inputs { for i := range inputs {
if inputs[i].Multimodal == nil { if inputs[i].Multimodal == nil {
if len(images) > 0 { for j := range images {
inputs[i].Multimodal = images[0].Multimodal if j == 0 {
inputs[i].MultimodalHash = images[0].MultimodalHash inputs[i].Multimodal = images[j].Multimodal
for j := 1; j < len(images); j++ { inputs[i].MultimodalHash = images[j].MultimodalHash
} else {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3) inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset() fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash) binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash) binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64() inputs[i].MultimodalHash = fnvHash.Sum64()
} }
images = nil
} }
images = nil
} else { } else {
images = append(images, inputs[i]) images = append(images, inputs[i])
inputs[i].Token = -1 inputs[i].Token = -1
} }
} }
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 }) for i := range inputs {
if inputs[i].Token == -1 {
imageInputs := []input.Input{
{Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>"
}
// <image_soft_token>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 262144}}, 256)...)
// <end_of_image>
imageInputs = append(imageInputs, input.Input{Token: 256000})
inputs = append(inputs[:i], append(imageInputs, inputs[i+1:]...)...)
}
}
return inputs, nil return inputs, nil
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var embeddings ml.Tensor
if opts.Multimodal != nil {
embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
}
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil { if err != nil {
return nil, err return nil, err
@ -153,7 +166,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err return nil, err
} }
return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil return m.TextModel.Forward(ctx, inputs, positions, outputs, opts.Multimodal, m.Cache), nil
} }
func init() { func init() {

View File

@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
) )
type TextOptions struct { type TextOptions struct {
@ -165,12 +166,15 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
if embeddings == nil { hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
embeddings = m.TokenEmbedding.Forward(ctx, inputs) if multimodal != nil {
visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(0))
} }
hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount { if len(m.Layers) == gemma27BLayerCount {
m.TextOptions.largeModelScaling = true m.TextOptions.largeModelScaling = true

View File

@ -4,6 +4,7 @@ import (
"cmp" "cmp"
"iter" "iter"
"log/slog" "log/slog"
"slices"
"strings" "strings"
"sync" "sync"
@ -39,8 +40,8 @@ type Vocabulary struct {
Scores []float32 Scores []float32
Merges []string Merges []string
BOS, EOS int32 BOS, EOS, EOT int32
AddBOS, AddEOS bool AddBOS, AddEOS, AddEOT bool
specialOnce sync.Once specialOnce sync.Once
special []string special []string
@ -57,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
case SpecialBOS: case SpecialBOS:
return id == v.BOS return id == v.BOS
case SpecialEOS: case SpecialEOS:
return id == v.EOS return id == v.EOS || id == v.EOT
default: default:
return false return false
} }
@ -85,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string { func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() { v.specialOnce.Do(func() {
for i := range v.Values { for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL { if slices.Contains([]int{105, 106}, i) {
v.special = append(v.special, v.Values[i])
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i]) v.special = append(v.special, v.Values[i])
} }
} }