use fast attention

Michael Yang 2025-03-07 17:38:36 -08:00
parent 0e886595bf
commit 8934324b72
3 changed files with 8 additions and 14 deletions


@@ -958,9 +958,9 @@ func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) m
 	var tt *C.struct_ggml_tensor
 	switch len(strides) {
 	case 0:
-		tt = C.ggml_set_1d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
+		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
 	case 1:
-		tt = C.ggml_set_2d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
+		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
 	default:
 		panic("unsupported number of dimensions")
 	}
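As I read ggml's API, the non-inplace ggml_set_1d/ggml_set_2d record the write as a new graph node and leave the destination tensor's own data untouched, whereas the _inplace variants overwrite it directly. A minimal usage sketch of the Set wrapper above, with a hypothetical cache/chunk pair that is not taken from the commit:

// Hypothetical sketch: with the non-inplace ggml_set_* variants, Set returns
// a fresh tensor, so the caller must use the returned value rather than
// assuming cache was mutated in place.
func writeChunk(ctx ml.Context, cache, chunk ml.Tensor, offset int) ml.Tensor {
	// no strides -> the ggml_set_1d branch above
	// one stride -> the ggml_set_2d branch above
	return cache.Set(ctx, chunk, offset)
}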


@@ -138,8 +138,8 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
 			{Token: 255999}, // "<start_of_image>"
 		}
-		// <image_soft_token>
-		imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 262144}}, 256)...)
+		// pad inputs with placeholders for image embeddings
+		imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 0}}, 256)...)
 		// <end_of_image>
 		imageInputs = append(imageInputs, input.Input{Token: 256000})
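Read together with the surrounding context lines, each image is now tokenized as a start marker, 256 placeholder slots (filled with image embeddings downstream, per the new comment), and an end marker. A sketch of that per-image layout, using only the token IDs visible in this hunk and the input/slices packages it already uses; the enclosing loop is omitted:

// Illustrative only: the per-image input sequence built by PostTokenize
// after this change.
imageInputs := []input.Input{
	{Token: 255999}, // "<start_of_image>"
}
// 256 placeholder slots to be replaced by image embeddings later
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 0}}, 256)...)
// "<end_of_image>"
imageInputs = append(imageInputs, input.Input{Token: 256000})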


@@ -24,17 +24,11 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
 	key := sa.Key.Forward(ctx, hiddenState)
 	value := sa.Value.Forward(ctx, hiddenState)
-	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
-	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
-	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize).Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Softmax(ctx)
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
+	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
+	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
+	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 	hiddenState = sa.Output.Forward(ctx, attention)
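For reference, the deleted lines spelled out plain scaled dot-product attention, the standard softmax(QK^T / sqrt(headDim)) · V. nn.Attention is assumed to compute the same thing, presumably via a fused "fast" attention kernel when the backend offers one, which is what the commit title refers to. A minimal sketch using the same tensor ops as the removed code, with head reshapes and permutes left out for brevity:

// Naive scaled dot-product attention, reconstructed from the deleted lines;
// nn.Attention should be equivalent up to tensor layout handling.
func naiveSDPA(ctx ml.Context, query, key, value ml.Tensor, headDim int) ml.Tensor {
	scores := key.Mulmat(ctx, query)                             // score matrix (argument order as in the old code)
	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) // scale by 1/sqrt(headDim)
	scores = scores.Softmax(ctx)                                 // attention weights
	return value.Mulmat(ctx, scores)                             // weighted sum of values
}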