This commit is contained in:
jmorganca 2025-05-12 19:15:42 -07:00
parent 18d52686de
commit 5c76074f66
4 changed files with 61 additions and 6 deletions

View File

@ -164,6 +164,7 @@ type Tensor interface {
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, sections [4]int32, ropeType uint32, base, scale float32) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Sin(ctx Context) Tensor

View File

@ -1104,7 +1104,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
ropeFactors.(*Tensor).t,
C.int(ropeDim),
C.int(ropeType),
C.int(opts.DefaultContextLen),
C.int(128000),
C.float(ropeBase),
C.float(ropeScale),
C.float(opts.YarnExtFactor),
@ -1115,6 +1115,37 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
}
}
// RoPEMulti runs ggml_rope_multi over t and wraps the result in a new Tensor
// that shares t's backend.
//
// positionIDs must be backed by *Tensor. ropeFactors must be either nil or
// backed by *Tensor. ropeDim, sections, ropeType, ropeBase and ropeScale are
// forwarded to ggml as given; the remaining ggml arguments are fixed to the
// constants annotated in the call below.
func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int32, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
	// A nil ropeFactors is replaced by an empty wrapper (its underlying t
	// field is left unset) so the type assertion in the call still succeeds.
	if ropeFactors == nil {
		ropeFactors = &Tensor{b: t.b}
	}

	// Quantized inputs are cast to F32 before being handed to the rope op.
	src := t.t
	if C.ggml_is_quantized(src._type) {
		src = C.ggml_cast(ctx.(*Context).ctx, src, C.GGML_TYPE_F32)
	}

	rotated := C.ggml_rope_multi(
		ctx.(*Context).ctx,
		src,
		positionIDs.(*Tensor).t,
		ropeFactors.(*Tensor).t,
		C.int(ropeDim),
		(*C.int)(&sections[0]), // sections is passed by value, so this points at the local copy
		C.int(ropeType),
		C.int(128000), // Default context length
		C.float(ropeBase),
		C.float(ropeScale),
		C.float(0.0),  // ext_factor
		C.float(1.0),  // attn_factor
		C.float(32.0), // beta_fast
		C.float(1.0),  // beta_slow
	)

	return &Tensor{b: t.b, t: rotated}
}
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,

View File

@ -72,13 +72,13 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
}
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
return &chunks{Model: m, Tensor: visionOutputs}, nil
return &chunks{Model: m, Tensor: visionOutputs, grid: grid}, nil
}
// chunks is the value EncodeMultimodal returns for one encoded image: it
// bundles the owning Model, the vision model's output tensor, and the grid
// that was passed to the vision forward pass.
type chunks struct {
// Model is the model that produced this encoding (embedded).
*Model
// Tensor is the output of VisionModel.Forward (embedded).
ml.Tensor
// grid is the Grid given to VisionModel.Forward together with the pixels.
// Model.Forward reads it back (its Width and Height) when building the
// position values for an image batch.
grid *Grid
// NOTE(review): dataOnce and data look like a compute-once cache of the
// tensor's float32 contents; their use is not visible here, so confirm.
dataOnce sync.Once
data []float32
}
@ -153,7 +153,28 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
fmt.Println("Forward")
pos := make([]int32, len(batch.Positions)*4)
var grid = &Grid{}
if len(batch.Multimodal) > 0 {
image := batch.Multimodal[0].Multimodal
grid = image.(*chunk).chunks.grid
for y := 0; y < grid.Height/2; y++ {
for x := 0; x < grid.Width/2; x++ {
i := y*grid.Width/2 + x
pos[i] = batch.Positions[i]
pos[i+len(batch.Positions)] = batch.Positions[i] + int32(y)
pos[i+len(batch.Positions)*2] = batch.Positions[i] + int32(x)
pos[i+len(batch.Positions)*3] = 0
}
}
} else {
copy(pos[:len(batch.Positions)], batch.Positions)
copy(pos[len(batch.Positions):len(batch.Positions)*2], batch.Positions)
copy(pos[len(batch.Positions)*2:len(batch.Positions)*3], batch.Positions)
}
positions, err := ctx.Input().FromIntSlice(pos, len(pos))
if err != nil {
return nil, err
}

View File

@ -76,13 +76,15 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
sections := [4]int32{16, 24, 24, 0}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
q = q.RoPEMulti(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, sections, 8, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
k = k.RoPEMulti(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, sections, 8, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)