wip
This commit is contained in:
parent
18d52686de
commit
5c76074f66
@ -164,6 +164,7 @@ type Tensor interface {
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
|
||||
RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, sections [4]int32, ropeType uint32, base, scale float32) Tensor
|
||||
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
Sin(ctx Context) Tensor
|
||||
|
@ -1104,7 +1104,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
ropeFactors.(*Tensor).t,
|
||||
C.int(ropeDim),
|
||||
C.int(ropeType),
|
||||
C.int(opts.DefaultContextLen),
|
||||
C.int(128000),
|
||||
C.float(ropeBase),
|
||||
C.float(ropeScale),
|
||||
C.float(opts.YarnExtFactor),
|
||||
@ -1115,6 +1115,37 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int32, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
|
||||
if ropeFactors == nil {
|
||||
ropeFactors = &Tensor{b: t.b}
|
||||
}
|
||||
|
||||
dequant := t.t
|
||||
if C.ggml_is_quantized(t.t._type) {
|
||||
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_rope_multi(
|
||||
ctx.(*Context).ctx,
|
||||
dequant,
|
||||
positionIDs.(*Tensor).t,
|
||||
ropeFactors.(*Tensor).t,
|
||||
C.int(ropeDim),
|
||||
(*C.int)(§ions[0]),
|
||||
C.int(ropeType),
|
||||
C.int(128000), // Default context length
|
||||
C.float(ropeBase),
|
||||
C.float(ropeScale),
|
||||
C.float(0.0), // ext_factor
|
||||
C.float(1.0), // attn_factor
|
||||
C.float(32.0), // beta_fast
|
||||
C.float(1.0), // beta_slow
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
|
@ -72,13 +72,13 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
||||
return &chunks{Model: m, Tensor: visionOutputs}, nil
|
||||
return &chunks{Model: m, Tensor: visionOutputs, grid: grid}, nil
|
||||
}
|
||||
|
||||
type chunks struct {
|
||||
*Model
|
||||
ml.Tensor
|
||||
|
||||
grid *Grid
|
||||
dataOnce sync.Once
|
||||
data []float32
|
||||
}
|
||||
@ -153,7 +153,28 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
fmt.Println("Forward")
|
||||
pos := make([]int32, len(batch.Positions)*4)
|
||||
var grid = &Grid{}
|
||||
if len(batch.Multimodal) > 0 {
|
||||
image := batch.Multimodal[0].Multimodal
|
||||
grid = image.(*chunk).chunks.grid
|
||||
for y := 0; y < grid.Height/2; y++ {
|
||||
for x := 0; x < grid.Width/2; x++ {
|
||||
i := y*grid.Width/2 + x
|
||||
pos[i] = batch.Positions[i]
|
||||
pos[i+len(batch.Positions)] = batch.Positions[i] + int32(y)
|
||||
pos[i+len(batch.Positions)*2] = batch.Positions[i] + int32(x)
|
||||
pos[i+len(batch.Positions)*3] = 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
copy(pos[:len(batch.Positions)], batch.Positions)
|
||||
copy(pos[len(batch.Positions):len(batch.Positions)*2], batch.Positions)
|
||||
copy(pos[len(batch.Positions)*2:len(batch.Positions)*3], batch.Positions)
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(pos, len(pos))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -76,13 +76,15 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
sections := [4]int32{16, 24, 24, 0}
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
|
||||
q = q.RoPEMulti(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, sections, 8, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
|
||||
k = k.RoPEMulti(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, sections, 8, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
Loading…
x
Reference in New Issue
Block a user