set non-causal attention

This commit is contained in:
Michael Yang 2025-03-07 13:52:45 -08:00
parent 631fecc6d9
commit 0df1800436
6 changed files with 57 additions and 25 deletions

View File

@ -58,9 +58,6 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
kv["tokenizer.ggml.bos_token_id"] = uint32(2)
kv["tokenizer.ggml.eot_token_id"] = uint32(1)
return kv
}

View File

@ -148,6 +148,7 @@ type Tensor interface {
View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
Pad(ctx Context, shape ...int) Tensor
Unpad(ctx Context, shape ...int) Tensor

View File

@ -954,6 +954,20 @@ func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
}
}
// Set copies t2 into t using ggml's in-place set operations and returns a
// Tensor that wraps the resulting ggml tensor and shares t's backend.
//
// offset is forwarded to ggml as the destination offset (ggml.h documents
// this value as being in bytes). The number of strides selects the operation:
//
//   - no strides: ggml_set_1d_inplace(ctx, t, t2, offset)
//   - one stride: ggml_set_2d_inplace(ctx, t, t2, nb1, offset), where
//     strides[0] is passed as nb1
//
// Set panics if more than one stride is supplied. The type assertions on ctx
// and t2 are unchecked, so it also panics if either is not this package's
// *Context / *Tensor.
func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
	var tt *C.struct_ggml_tensor
	switch len(strides) {
	case 0:
		tt = C.ggml_set_1d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
	case 1:
		// ggml declares ggml_set_2d_inplace as (ctx, a, b, nb1, offset): the
		// stride comes BEFORE the offset. Passing them the other way round
		// makes ggml treat the offset as the row stride and vice versa.
		tt = C.ggml_set_2d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(strides[0]), C.size_t(offset))
	default:
		panic("unsupported number of dimensions")
	}

	return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {

View File

@ -51,8 +51,10 @@ func New(c ml.Config) (model.Model, error) {
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
EOS: int32(1),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOT: int32(106),
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
},
),
ImageProcessor: newImageProcessor(c),
@ -109,35 +111,46 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
for j := range images {
if j == 0 {
inputs[i].Multimodal = images[j].Multimodal
inputs[i].MultimodalHash = images[j].MultimodalHash
} else {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
images = nil
}
images = nil
} else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
for i := range inputs {
if inputs[i].Token == -1 {
imageInputs := []input.Input{
{Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>"
}
// <image_soft_token>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 262144}}, 256)...)
// <end_of_image>
imageInputs = append(imageInputs, input.Input{Token: 256000})
inputs = append(inputs[:i], append(imageInputs, inputs[i+1:]...)...)
}
}
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var embeddings ml.Tensor
if opts.Multimodal != nil {
embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
}
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
@ -153,7 +166,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts.Multimodal, m.Cache), nil
}
func init() {

View File

@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
@ -165,12 +166,15 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
if embeddings == nil {
embeddings = m.TokenEmbedding.Forward(ctx, inputs)
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
if multimodal != nil {
visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(0))
}
hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.TextOptions.largeModelScaling = true

View File

@ -4,6 +4,7 @@ import (
"cmp"
"iter"
"log/slog"
"slices"
"strings"
"sync"
@ -39,8 +40,8 @@ type Vocabulary struct {
Scores []float32
Merges []string
BOS, EOS int32
AddBOS, AddEOS bool
BOS, EOS, EOT int32
AddBOS, AddEOS, AddEOT bool
specialOnce sync.Once
special []string
@ -57,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
case SpecialBOS:
return id == v.BOS
case SpecialEOS:
return id == v.EOS
return id == v.EOS || id == v.EOT
default:
return false
}
@ -85,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL {
if slices.Contains([]int{105, 106}, i) {
v.special = append(v.special, v.Values[i])
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}