wip: implementing rope

This commit is contained in:
Bruce MacDonald 2025-04-21 18:50:36 -07:00
parent eedc969c35
commit 5ff0d538b0
6 changed files with 153 additions and 29 deletions

View File

@ -25,6 +25,7 @@ type qwen25VLModel struct {
RMSNormEPS float32 `json:"rms_norm_eps"` RMSNormEPS float32 `json:"rms_norm_eps"`
VisionModel struct { VisionModel struct {
SpatialMergeSize uint32 `json:"spatial_merge_size"` // TODO: is this set?
} `json:"vision_config"` } `json:"vision_config"`
} }

View File

@ -570,6 +570,10 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return out return out
} }
// Div implements ml.Tensor for the test double. Element-wise division is
// not exercised by any current test, so it remains an unimplemented stub.
func (t *testTensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	panic("not implemented")
}
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented") panic("not implemented")
} }
@ -625,6 +629,10 @@ func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
panic("not implemented") panic("not implemented")
} }
// Exp implements ml.Tensor for the test double. The element-wise
// exponential is not exercised by any current test, so it panics.
func (t *testTensor) Exp(ctx ml.Context) ml.Tensor {
	panic("not implemented")
}
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") } func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") } func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") } func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }

View File

@ -178,6 +178,8 @@ type Tensor interface {
Neg(ctx Context) Tensor Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor Add(ctx Context, t2 Tensor) Tensor
// Div computes the element-wise division (t1 / t2) for all values in the tensor
Div(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor Mul(ctx Context, t2 Tensor) Tensor
Mulmat(ctx Context, t2 Tensor) Tensor Mulmat(ctx Context, t2 Tensor) Tensor
MulmatFullPrec(ctx Context, t2 Tensor) Tensor MulmatFullPrec(ctx Context, t2 Tensor) Tensor
@ -198,6 +200,8 @@ type Tensor interface {
Sin(ctx Context) Tensor Sin(ctx Context) Tensor
Cos(ctx Context) Tensor Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor Tanh(ctx Context) Tensor
// Exp computes the element-wise exponential (e^t) for all values in the tensor
Exp(ctx Context) Tensor
GELU(ctx Context) Tensor GELU(ctx Context) Tensor
SILU(ctx Context) Tensor SILU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor Sigmoid(ctx Context) Tensor

View File

@ -860,6 +860,13 @@ func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
} }
} }
// Div computes the element-wise quotient t / t2, returning a new graph
// node (out-of-place ggml_div); neither operand is modified.
func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
	}
}
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
@ -1017,6 +1024,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
} }
} }
// Exp computes the element-wise exponential e^t.
// NOTE(review): this uses ggml_exp_inplace, so the returned tensor aliases
// (and overwrites) the receiver's buffer, unlike Div/Mul which are
// out-of-place — confirm no caller reuses t after calling Exp.
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
	return &Tensor{
		b: t.b,
		t: C.ggml_exp_inplace(ctx.(*Context).ctx, t.t),
	}
}
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor { func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
if len(shape) != 4 { if len(shape) != 4 {
panic("expected 4 dimensions") panic("expected 4 dimensions")

View File

@ -92,7 +92,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return nil, fmt.Errorf("failed to create tensor from image: %w", err) return nil, fmt.Errorf("failed to create tensor from image: %w", err)
} }
visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.eps) visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.eps)
return &imageFeatures{ return &imageFeatures{

View File

@ -2,6 +2,7 @@ package qwen25vl
import ( import (
"math" "math"
"slices"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
@ -166,6 +167,121 @@ func (pm *VisionPatchMerger) Forward(ctx ml.Context, x ml.Tensor, outDim, contex
return x return x
} }
// rope builds the 2D rotary positional embedding tensor for the vision
// grid by combining inverse rotation frequencies with stacked
// height/width position IDs.
func rope(ctx ml.Context, grid *Grid) ml.Tensor {
	const (
		halfDim = 80 / 2  // TODO: get this from config
		base    = 10000.0 // TODO: get this from config ropeTheta
		spatial = 2       // merging factor for spatial dimensions
	)

	// Inverse frequencies for half of the head dimension.
	frequencies := freqInv(ctx, halfDim, base)

	// Per-patch position IDs along each spatial axis, flattened to 1D.
	rows := heightPos(ctx, grid, spatial)
	cols := widthPos(ctx, grid, spatial)

	// Stack the two axes into a single position tensor.
	stacked := rows.Reshape(ctx, 1, rows.Dim(0)).
		Stack(ctx, 0, cols.Reshape(ctx, 1, cols.Dim(0)))

	return rotEmbed(ctx, frequencies, grid.Width, stacked)
}
// freqInv returns the inverse rotary frequencies 1 / theta^(i/dim) for
// even i in [0, dim), computed as exp((i/dim)*ln(theta)) and then
// inverted by dividing a tensor of ones.
func freqInv(ctx ml.Context, dim int, theta float64) ml.Tensor {
	logTheta, err := ctx.Input().FromFloatSlice([]float32{float32(math.Log(theta))}, 1)
	if err != nil {
		panic(err) // TODO: handle error
	}

	// Exponents i/dim for i = 0, 2, 4, ..., dim-2.
	exponents := ctx.Arange(0, float32(dim), 2, ml.DTypeF32)
	divisor, err := ctx.Input().FromFloatSlice([]float32{float32(dim)}, 1)
	if err != nil {
		panic(err) // TODO: handle error
	}
	exponents = exponents.Div(ctx, divisor)

	// theta^(i/dim) = exp((i/dim) * ln(theta)).
	scaled := exponents.Mul(ctx, logTheta).Exp(ctx)

	// Element-wise reciprocal via ones / scaled.
	ones, err := ctx.Input().FromFloatSlice(slices.Repeat([]float32{1.0}, scaled.Shape()[0]), scaled.Shape()...)
	if err != nil {
		panic(err) // TODO: handle error
	}
	return ones.Div(ctx, scaled)
}
// heightPos generates the per-patch row (height) position IDs for the
// grid, regrouped so that positions within each merge×merge spatial block
// are contiguous, then flattened to a 1D tensor of Height*Width values.
func heightPos(ctx ml.Context, grid *Grid, merge int) ml.Tensor {
	// Each grid row contributes its row index repeated Width times, so the
	// width axis varies fastest in the flat slice.
	data := make([]float32, 0, grid.Height*grid.Width)
	for i := 0; i < grid.Height; i++ {
		data = append(data, slices.Repeat([]float32{float32(i)}, grid.Width)...)
	}

	// Shape is (width, height) with width innermost, matching the identical
	// data layout in widthPos. The previous (height, width) ordering was
	// inconsistent and only harmless because of the immediate Reshape below.
	pos, err := ctx.Input().FromFloatSlice(data, grid.Width, grid.Height)
	if err != nil {
		panic(err)
	}

	// Regroup into merge-sized spatial blocks and swap the block axes.
	pos = pos.Reshape(
		ctx,
		merge,
		grid.Width/merge,
		merge,
		grid.Height/merge,
	)
	pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	// Flatten to a 1D tensor.
	return pos.Reshape(ctx, pos.Dim(0)*pos.Dim(1)*pos.Dim(2)*pos.Dim(3))
}
// widthPos generates the per-patch column (width) position IDs for the
// grid, regrouped so that positions within each merge×merge spatial block
// are contiguous, then flattened to a 1D tensor of Height*Width values.
func widthPos(ctx ml.Context, grid *Grid, merge int) ml.Tensor {
	// One row of column indices 0..Width-1, repeated for every grid row.
	row := make([]float32, grid.Width)
	for j := range row {
		row[j] = float32(j)
	}
	data := make([]float32, 0, grid.Height*grid.Width)
	for i := 0; i < grid.Height; i++ {
		data = append(data, row...)
	}

	// Width is the innermost (fastest-varying) dimension.
	pos, err := ctx.Input().FromFloatSlice(data, grid.Width, grid.Height)
	if err != nil {
		panic(err)
	}

	// Regroup into merge-sized spatial blocks and swap the block axes.
	pos = pos.Reshape(
		ctx,
		merge,
		grid.Width/merge,
		merge,
		grid.Height/merge,
	)
	pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	// Flatten to a 1D tensor.
	return pos.Reshape(ctx, pos.Dim(0)*pos.Dim(1)*pos.Dim(2)*pos.Dim(3))
}
// rotEmbed builds rotary embeddings by forming the outer product of the
// candidate positions [0, maxSize) with the inverse frequencies, then
// gathering the rows addressed by the flattened position IDs.
func rotEmbed(ctx ml.Context, freqInv ml.Tensor, maxSize int, pos ml.Tensor) ml.Tensor {
	// Candidate positions 0, 1, ..., maxSize-1.
	indices := ctx.Arange(0, float32(maxSize), 1, ml.DTypeF32)

	// Outer product: one embedding row per candidate position.
	table := freqInv.Reshape(ctx, 1, freqInv.Shape()[0]).
		Mulmat(ctx, indices.Reshape(ctx, 1, maxSize))

	// Select the rows for the actual (flattened) position IDs.
	flat := pos.Reshape(ctx, pos.Dim(0)*pos.Dim(1))
	return table.Rows(ctx, flat)
}
// VisionModel implements the Qwen vision model // VisionModel implements the Qwen vision model
type VisionModel struct { type VisionModel struct {
PatchEmbedding *PatchEmbedding PatchEmbedding *PatchEmbedding
@ -178,11 +294,6 @@ type VisionModel struct {
// Forward computes the vision model for an input tensor // Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor { func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Calculate position IDs for 2D RoPE
numPatchesH := pixelValues.Dim(0) / m.patchSize
numPatchesW := pixelValues.Dim(1) / m.patchSize
numPatches := numPatchesH * numPatchesW
// Extract patch embeddings // Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward( hiddenStates := m.PatchEmbedding.Forward(
ctx, ctx,
@ -192,32 +303,18 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
m.patchSize, // patch size, e.g., 14 m.patchSize, // patch size, e.g., 14
) )
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position rope := rope(ctx, grid)
positions := make([]int32, numPatches*4)
for h := 0; h < numPatchesH; h++ { // spatialMergeSize := 2 // TODO: get this from config
for w := 0; w < numPatchesW; w++ { // // Create the position IDs tensor with correct dimensions
idx := h*numPatchesW + w // positions := []int32{}
// For each position, store both h and w coordinates twice
positions[idx*4] = int32(h) // y coordinate
positions[idx*4+1] = int32(w) // x coordinate
positions[idx*4+2] = int32(h) // y coordinate (repeated)
positions[idx*4+3] = int32(w) // x coordinate (repeated)
}
}
// Create the position IDs tensor with correct dimensions // // Apply encoder layers
positionIDs, err := ctx.Input().FromIntSlice(positions, numPatches*4) // for _, layer := range m.Layers {
if err != nil { // hiddenStates = layer.Forward(ctx, hiddenStates, positionIDs, m.VisionModelOptions)
panic(err) // }
}
// Apply encoder layers // hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, positionIDs, m.VisionModelOptions)
}
hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps)
return hiddenStates return hiddenStates
} }