From 760e8fa6561206a76c504e8f5f0da996aa25bffb Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 7 Feb 2025 19:31:50 -0800
Subject: [PATCH] tmp
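
Replace the temporary debug scaffolding in the bert model with a
working forward pass: drop the fmt.Println dumps and the early return
after the token embedding, retag weights from `ggml:` to `gguf:` keys,
apply MLPNorm before capturing the MLP residual, reshape keys as
(headDim, numHeads, batchSize) to match queries and values, load the
layer output norm from `layer_output_norm`, size Layers from
`block_count`, and plumb `pooling_type` through as a PoolingType
option with an initial mean-pooling path.

For orientation, a minimal standalone sketch of what conventional mean
pooling computes over per-token hidden states. This is an illustration
only, not code from this patch: the patch expresses the reduction as a
Mulmat against a weights tensor, and its divisor comes from summing
Positions(), which the in-code TODO marks as provisional.

    // meanPool averages per-token hidden states into one embedding.
    // Standalone illustration only; hidden is (tokens x dim).
    func meanPool(hidden [][]float32) []float32 {
    	if len(hidden) == 0 {
    		return nil
    	}
    	out := make([]float32, len(hidden[0]))
    	for _, tok := range hidden { // sum across tokens
    		for i, v := range tok {
    			out[i] += v
    		}
    	}
    	for i := range out { // normalize by the token count
    		out[i] /= float32(len(hidden))
    	}
    	return out
    }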
---
 model/bert/model.go | 84 ++++++++++++++++++++++++++++++---------------
 1 file changed, 56 insertions(+), 28 deletions(-)

diff --git a/model/bert/model.go b/model/bert/model.go
index e4f0bec27..6a30925c0 100644
--- a/model/bert/model.go
+++ b/model/bert/model.go
@@ -1,7 +1,6 @@
 package bert
 
 import (
-	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/ml"
@@ -13,21 +12,32 @@ func init() {
 	model.Register("bert", New)
 }
 
+type PoolingType int
+
+const (
+	PoolingTypeNone PoolingType = iota
+	PoolingTypeMean
+	PoolingTypeCLS
+	PoolingTypeLast
+	PoolingTypeRank
+)
+
 type Options struct {
 	hiddenSize, numHeads int64
 	eps                  float32
+	poolingType          PoolingType
 }
 
 type Model struct {
 	model.Base
 	model.BytePairEncoding
 
-	TokenEmbedding     *nn.Embedding `ggml:"token_embd"`
-	TypeEmbedding      *nn.Embedding `ggml:"type_embd,alt:token_types"`
-	PositionEmbedding  *nn.Embedding `ggml:"position_embd"`
-	TokenEmbeddingNorm *nn.LayerNorm `ggml:"token_embd_norm"`
+	TokenEmbedding     *nn.Embedding `gguf:"token_embd"`
+	TypeEmbedding      *nn.Embedding `gguf:"type_embd,alt:token_types"`
+	PositionEmbedding  *nn.Embedding `gguf:"position_embd"`
+	TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
 
-	Layers []EncoderLayer `ggml:"blk"`
+	Layers []EncoderLayer `gguf:"blk"`
 
 	*Options
 }
@@ -38,33 +48,49 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	if err != nil {
 		return nil, err
 	}
-	fmt.Println("inputs", inputs.Shape(), ml.Dump(inputs))
 
 	types, err := ctx.FromIntSlice([]int32{0}, 1)
 	if err != nil {
 		return nil, err
 	}
-	fmt.Println("types", types.Shape(), ml.Dump(types))
 
 	positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
 	if err != nil {
 		return nil, err
 	}
-	fmt.Println("positions", positions.Shape(), ml.Dump(positions))
 
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-	fmt.Println("TokenEmbedding.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
-	return hiddenState, nil
 	hiddenState = hiddenState.Add(ctx, m.TypeEmbedding.Forward(ctx, types))
-	fmt.Println("TypeEmbedding.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
 	hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positions))
-	fmt.Println("PositionEmbedding.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
 	hiddenState = m.TokenEmbeddingNorm.Forward(ctx, hiddenState, m.eps)
-	fmt.Println("TokenEmbeddingNorm.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
 
 	for i, layer := range m.Layers {
 		hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options)
-		fmt.Println("EncoderLayer.Forward", i, hiddenState.Shape(), ml.Dump(hiddenState))
+	}
+
+	switch m.poolingType {
+	case PoolingTypeMean:
+		sum := func(s []int32) (sum int32) {
+			for _, v := range s {
+				sum += v
+			}
+
+			return
+		}
+
+		// TODO: handle batch
+		f32s := make([]float32, len(opts.Positions())*len(opts.Positions()))
+		for i := range opts.Positions() {
+			f32s[i] = 1 / float32(sum(opts.Positions()))
+		}
+
+		means, err := ctx.FromFloatSlice(f32s, len(opts.Positions()), len(opts.Positions()))
+		if err != nil {
+			return nil, err
+		}
+
+		hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+		hiddenState = hiddenState.Mulmat(ctx, means)
 	}
 
 	return hiddenState, nil
@@ -72,9 +98,9 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 
 type EncoderLayer struct {
 	*SelfAttention
-	MLPNorm *nn.LayerNorm `ggml:"attn_output_norm"`
+	MLPNorm *nn.LayerNorm `gguf:"attn_output_norm"`
 	*MLP
-	LayerOutputNorm *nn.LayerNorm `ggml:"ffn_output_norm"`
+	LayerOutputNorm *nn.LayerNorm `gguf:"layer_output_norm"`
 }
 
 func (e *EncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
@@ -82,19 +108,19 @@ func (e *EncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tenso
 	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
 	hiddenState = hiddenState.Add(ctx, residual)
 
+	hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
 	residual = hiddenState
-	hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
 	hiddenState = hiddenState.Add(ctx, residual)
 
 	return e.LayerOutputNorm.Forward(ctx, hiddenState, opts.eps)
 }
 
 type SelfAttention struct {
-	Query  *nn.Linear `ggml:"attn_q"`
-	Key    *nn.Linear `ggml:"attn_k"`
-	Value  *nn.Linear `ggml:"attn_v"`
-	Output *nn.Linear `ggml:"attn_output"`
+	Query  *nn.Linear `gguf:"attn_q"`
+	Key    *nn.Linear `gguf:"attn_k"`
+	Value  *nn.Linear `gguf:"attn_v"`
+	Output *nn.Linear `gguf:"attn_output"`
 }
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
@@ -105,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 
 	key := sa.Key.Forward(ctx, hiddenState)
-	key = key.Reshape(ctx, opts.numHeads, headDim, batchSize)
+	key = key.Reshape(ctx, headDim, opts.numHeads, batchSize)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numHeads, batchSize)
@@ -128,8 +154,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 type MLP struct {
-	Up   *nn.Linear `ggml:"ffn_up"`
-	Down *nn.Linear `ggml:"ffn_down"`
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
@@ -138,6 +164,7 @@ func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml
 
 func New(c ml.Config) (model.Model, error) {
 	return &Model{
+		Layers: make([]EncoderLayer, c.Uint("block_count")),
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
@@ -149,9 +176,10 @@ func New(c ml.Config) (model.Model, error) {
 			},
 		),
 		Options: &Options{
-			hiddenSize: int64(c.Uint("embedding_length")),
-			numHeads:   int64(c.Uint("attention.head_count")),
-			eps:        c.Float("attention.layer_norm_epsilon"),
+			hiddenSize:  int64(c.Uint("embedding_length")),
+			numHeads:    int64(c.Uint("attention.head_count")),
+			eps:         c.Float("attention.layer_norm_epsilon"),
+			poolingType: PoolingType(c.Uint("pooling_type")),
 		},
 	}, nil
 }