This commit is contained in:
jmorganca 2025-05-12 19:15:42 -07:00
parent 18d52686de
commit 5c76074f66
4 changed files with 61 additions and 6 deletions

View File

@ -164,6 +164,7 @@ type Tensor interface {
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, sections [4]int32, ropeType uint32, base, scale float32) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Sin(ctx Context) Tensor Sin(ctx Context) Tensor

View File

@ -1104,7 +1104,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
ropeFactors.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim), C.int(ropeDim),
C.int(ropeType), C.int(ropeType),
C.int(opts.DefaultContextLen), C.int(128000),
C.float(ropeBase), C.float(ropeBase),
C.float(ropeScale), C.float(ropeScale),
C.float(opts.YarnExtFactor), C.float(opts.YarnExtFactor),
@ -1115,6 +1115,37 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
} }
} }
// RoPEMulti applies multi-section rotary position embeddings to t via
// ggml_rope_multi, rotating ropeDim dimensions split into the four groups
// given by sections (e.g. temporal/height/width/extra position components
// for multimodal models — assumption based on the 4-way split; confirm
// against the caller, which passes positionIDs laid out in 4 blocks).
func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int32, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
// ggml_rope_multi dereferences the factors tensor unconditionally, so
// substitute an empty placeholder when the caller passes nil.
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
// Quantized tensors are cast to F32 first; the rope kernel operates on
// float data (NOTE(review): presumably ggml_rope_multi rejects quantized
// inputs — confirm against ggml).
dequant := t.t
if C.ggml_is_quantized(t.t._type) {
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
}
return &Tensor{
b: t.b,
t: C.ggml_rope_multi(
ctx.(*Context).ctx,
dequant,
positionIDs.(*Tensor).t,
ropeFactors.(*Tensor).t,
C.int(ropeDim),
(*C.int)(&sections[0]),
C.int(ropeType),
C.int(128000), // Default context length
C.float(ropeBase),
C.float(ropeScale),
// YaRN parameters are fixed here rather than configurable; the
// sibling RoPE method reads them from opts instead — NOTE(review):
// consider plumbing these through options for consistency.
C.float(0.0), // ext_factor
C.float(1.0), // attn_factor
C.float(32.0), // beta_fast
C.float(1.0), // beta_slow
),
}
}
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{ return &Tensor{
b: t.b, b: t.b,

View File

@ -72,13 +72,13 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
} }
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid) visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
return &chunks{Model: m, Tensor: visionOutputs}, nil return &chunks{Model: m, Tensor: visionOutputs, grid: grid}, nil
} }
type chunks struct { type chunks struct {
*Model *Model
ml.Tensor ml.Tensor
grid *Grid
dataOnce sync.Once dataOnce sync.Once
data []float32 data []float32
} }
@ -153,7 +153,28 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) fmt.Println("Forward")
pos := make([]int32, len(batch.Positions)*4)
var grid = &Grid{}
if len(batch.Multimodal) > 0 {
image := batch.Multimodal[0].Multimodal
grid = image.(*chunk).chunks.grid
for y := 0; y < grid.Height/2; y++ {
for x := 0; x < grid.Width/2; x++ {
i := y*grid.Width/2 + x
pos[i] = batch.Positions[i]
pos[i+len(batch.Positions)] = batch.Positions[i] + int32(y)
pos[i+len(batch.Positions)*2] = batch.Positions[i] + int32(x)
pos[i+len(batch.Positions)*3] = 0
}
}
} else {
copy(pos[:len(batch.Positions)], batch.Positions)
copy(pos[len(batch.Positions):len(batch.Positions)*2], batch.Positions)
copy(pos[len(batch.Positions)*2:len(batch.Positions)*3], batch.Positions)
}
positions, err := ctx.Input().FromIntSlice(pos, len(pos))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -76,13 +76,15 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
batchSize := hiddenState.Dim(1) batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads headDim := opts.hiddenSize / opts.numHeads
sections := [4]int32{16, 24, 24, 0}
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen)) q = q.RoPEMulti(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, sections, 8, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen)) k = k.RoPEMulti(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, sections, 8, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)