wip: implementing rope
This commit is contained in:
parent
eedc969c35
commit
5ff0d538b0
@ -25,6 +25,7 @@ type qwen25VLModel struct {
|
|||||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||||
|
|
||||||
VisionModel struct {
|
VisionModel struct {
|
||||||
|
SpatialMergeSize uint32 `json:"spatial_merge_size"` // TODO: is this set?
|
||||||
} `json:"vision_config"`
|
} `json:"vision_config"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -570,6 +570,10 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
@ -625,6 +629,10 @@ func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
|
|||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) Exp(ctx ml.Context) ml.Tensor {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||||
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||||
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||||
|
@ -178,6 +178,8 @@ type Tensor interface {
|
|||||||
|
|
||||||
Neg(ctx Context) Tensor
|
Neg(ctx Context) Tensor
|
||||||
Add(ctx Context, t2 Tensor) Tensor
|
Add(ctx Context, t2 Tensor) Tensor
|
||||||
|
// Div computes the element-wise division (t1 / t2) for all values in the tensor
|
||||||
|
Div(ctx Context, t2 Tensor) Tensor
|
||||||
Mul(ctx Context, t2 Tensor) Tensor
|
Mul(ctx Context, t2 Tensor) Tensor
|
||||||
Mulmat(ctx Context, t2 Tensor) Tensor
|
Mulmat(ctx Context, t2 Tensor) Tensor
|
||||||
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
|
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
|
||||||
@ -198,6 +200,8 @@ type Tensor interface {
|
|||||||
Sin(ctx Context) Tensor
|
Sin(ctx Context) Tensor
|
||||||
Cos(ctx Context) Tensor
|
Cos(ctx Context) Tensor
|
||||||
Tanh(ctx Context) Tensor
|
Tanh(ctx Context) Tensor
|
||||||
|
// Exp computes the element-wise exponential (e^t) for all values in the tensor
|
||||||
|
Exp(ctx Context) Tensor
|
||||||
GELU(ctx Context) Tensor
|
GELU(ctx Context) Tensor
|
||||||
SILU(ctx Context) Tensor
|
SILU(ctx Context) Tensor
|
||||||
Sigmoid(ctx Context) Tensor
|
Sigmoid(ctx Context) Tensor
|
||||||
|
@ -860,6 +860,13 @@ func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
b: t.b,
|
b: t.b,
|
||||||
@ -1017,6 +1024,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_exp_inplace(ctx.(*Context).ctx, t.t),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||||
if len(shape) != 4 {
|
if len(shape) != 4 {
|
||||||
panic("expected 4 dimensions")
|
panic("expected 4 dimensions")
|
||||||
|
@ -92,7 +92,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
|||||||
return nil, fmt.Errorf("failed to create tensor from image: %w", err)
|
return nil, fmt.Errorf("failed to create tensor from image: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
|
||||||
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.eps)
|
visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.eps)
|
||||||
|
|
||||||
return &imageFeatures{
|
return &imageFeatures{
|
||||||
|
@ -2,6 +2,7 @@ package qwen25vl
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
@ -166,6 +167,121 @@ func (pm *VisionPatchMerger) Forward(ctx ml.Context, x ml.Tensor, outDim, contex
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rope(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||||
|
dim := 80 / 2 // TODO: get this from config
|
||||||
|
theta := float64(10000.0) // TODO: get this from config ropeTheta
|
||||||
|
merge := 2 // Merging factor for spatial dimensions
|
||||||
|
|
||||||
|
// Calculate inverse frequencies for rotation
|
||||||
|
inv := freqInv(ctx, dim, theta)
|
||||||
|
|
||||||
|
// Generate and stack position IDs for height and width dimensions
|
||||||
|
hPos := heightPos(ctx, grid, merge)
|
||||||
|
wPos := widthPos(ctx, grid, merge)
|
||||||
|
// Reshape both and stack them
|
||||||
|
tmp := hPos.Reshape(ctx, 1, hPos.Dim(0))
|
||||||
|
pos := tmp.Stack(ctx, 0, wPos.Reshape(ctx, 1, wPos.Dim(0)))
|
||||||
|
|
||||||
|
// Generate rotary embeddings
|
||||||
|
return rotEmbed(ctx, inv, grid.Width, pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// freqInv calculates the inverse frequencies for rotary embeddings
|
||||||
|
func freqInv(ctx ml.Context, dim int, theta float64) ml.Tensor {
|
||||||
|
logBase, err := ctx.Input().FromFloatSlice([]float32{float32(math.Log(theta))}, 1)
|
||||||
|
if err != nil {
|
||||||
|
panic(err) // TODO: handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create powers divided by dimension (0, 2, 4, ..., dim-2) / dim
|
||||||
|
powers := ctx.Arange(0, float32(dim), 2, ml.DTypeF32)
|
||||||
|
dims, err := ctx.Input().FromFloatSlice([]float32{float32(dim)}, 1)
|
||||||
|
if err != nil {
|
||||||
|
panic(err) // TODO: handle error
|
||||||
|
}
|
||||||
|
powers = powers.Div(ctx, dims)
|
||||||
|
|
||||||
|
// Calculate inverse frequencies: 1 / (theta ^ (powers/dim))
|
||||||
|
dims = powers.Mul(ctx, logBase).Exp(ctx)
|
||||||
|
ones, err := ctx.Input().FromFloatSlice(slices.Repeat([]float32{1.0}, dims.Shape()[0]), dims.Shape()...)
|
||||||
|
if err != nil {
|
||||||
|
panic(err) // TODO: handle error
|
||||||
|
}
|
||||||
|
return ones.Div(ctx, dims)
|
||||||
|
}
|
||||||
|
|
||||||
|
// heightPos generates position IDs for the height dimension
|
||||||
|
func heightPos(ctx ml.Context, grid *Grid, merge int) ml.Tensor {
|
||||||
|
// Create a slice where each row contains the same height value repeated width times
|
||||||
|
data := make([]float32, 0, grid.Height*grid.Width)
|
||||||
|
for i := 0; i < grid.Height; i++ {
|
||||||
|
data = append(data, slices.Repeat([]float32{float32(i)}, grid.Width)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create pos with shape [height, width]
|
||||||
|
pos, err := ctx.Input().FromFloatSlice(data, grid.Height, grid.Width)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reshape and permute for spatial merging
|
||||||
|
pos = pos.Reshape(
|
||||||
|
ctx,
|
||||||
|
merge,
|
||||||
|
grid.Width/merge,
|
||||||
|
merge,
|
||||||
|
grid.Height/merge,
|
||||||
|
)
|
||||||
|
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
|
||||||
|
// Flatten to 1D tensor
|
||||||
|
return pos.Reshape(ctx, pos.Dim(0)*pos.Dim(1)*pos.Dim(2)*pos.Dim(3))
|
||||||
|
}
|
||||||
|
|
||||||
|
// widthPos generates position IDs for the width dimension
|
||||||
|
func widthPos(ctx ml.Context, grid *Grid, merge int) ml.Tensor {
|
||||||
|
// Create a slice containing width values in column-major order
|
||||||
|
data := make([]float32, 0, grid.Height*grid.Width)
|
||||||
|
for i := 0; i < grid.Height; i++ {
|
||||||
|
for j := 0; j < grid.Width; j++ {
|
||||||
|
data = append(data, float32(j))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create pos with shape [width, height]
|
||||||
|
pos, err := ctx.Input().FromFloatSlice(data, grid.Width, grid.Height)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reshape and permute for spatial merging
|
||||||
|
pos = pos.Reshape(
|
||||||
|
ctx,
|
||||||
|
merge,
|
||||||
|
grid.Width/merge,
|
||||||
|
merge,
|
||||||
|
grid.Height/merge,
|
||||||
|
)
|
||||||
|
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
|
||||||
|
// Flatten to 1D tensor
|
||||||
|
return pos.Reshape(ctx, pos.Dim(0)*pos.Dim(1)*pos.Dim(2)*pos.Dim(3))
|
||||||
|
}
|
||||||
|
|
||||||
|
// rotEmbed generates rotary embeddings using inverse frequencies and position IDs
|
||||||
|
func rotEmbed(ctx ml.Context, freqInv ml.Tensor, maxSize int, pos ml.Tensor) ml.Tensor {
|
||||||
|
// Create sequence tensor [0, 1, 2, ..., maxGridSize-1]
|
||||||
|
seq := ctx.Arange(0, float32(maxSize), 1, ml.DTypeF32)
|
||||||
|
|
||||||
|
// Reshape for matrix multiplication and calculate outer product
|
||||||
|
outer := freqInv.Reshape(ctx, 1, freqInv.Shape()[0]).Mulmat(ctx, seq.Reshape(ctx, 1, maxSize))
|
||||||
|
|
||||||
|
// Flatten position IDs and use as indices to select rows from outer product
|
||||||
|
return outer.Rows(ctx, pos.Reshape(ctx, pos.Dim(0)*pos.Dim(1)))
|
||||||
|
|
||||||
|
// TODO: index position IDs and flatten
|
||||||
|
}
|
||||||
|
|
||||||
// VisionModel implements the Qwen vision model
|
// VisionModel implements the Qwen vision model
|
||||||
type VisionModel struct {
|
type VisionModel struct {
|
||||||
PatchEmbedding *PatchEmbedding
|
PatchEmbedding *PatchEmbedding
|
||||||
@ -178,11 +294,6 @@ type VisionModel struct {
|
|||||||
|
|
||||||
// Forward computes the vision model for an input tensor
|
// Forward computes the vision model for an input tensor
|
||||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
|
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
|
||||||
// Calculate position IDs for 2D RoPE
|
|
||||||
numPatchesH := pixelValues.Dim(0) / m.patchSize
|
|
||||||
numPatchesW := pixelValues.Dim(1) / m.patchSize
|
|
||||||
numPatches := numPatchesH * numPatchesW
|
|
||||||
|
|
||||||
// Extract patch embeddings
|
// Extract patch embeddings
|
||||||
hiddenStates := m.PatchEmbedding.Forward(
|
hiddenStates := m.PatchEmbedding.Forward(
|
||||||
ctx,
|
ctx,
|
||||||
@ -192,32 +303,18 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
m.patchSize, // patch size, e.g., 14
|
m.patchSize, // patch size, e.g., 14
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
|
rope := rope(ctx, grid)
|
||||||
positions := make([]int32, numPatches*4)
|
|
||||||
|
|
||||||
for h := 0; h < numPatchesH; h++ {
|
// spatialMergeSize := 2 // TODO: get this from config
|
||||||
for w := 0; w < numPatchesW; w++ {
|
// // Create the position IDs tensor with correct dimensions
|
||||||
idx := h*numPatchesW + w
|
// positions := []int32{}
|
||||||
// For each position, store both h and w coordinates twice
|
|
||||||
positions[idx*4] = int32(h) // y coordinate
|
|
||||||
positions[idx*4+1] = int32(w) // x coordinate
|
|
||||||
positions[idx*4+2] = int32(h) // y coordinate (repeated)
|
|
||||||
positions[idx*4+3] = int32(w) // x coordinate (repeated)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the position IDs tensor with correct dimensions
|
// // Apply encoder layers
|
||||||
positionIDs, err := ctx.Input().FromIntSlice(positions, numPatches*4)
|
// for _, layer := range m.Layers {
|
||||||
if err != nil {
|
// hiddenStates = layer.Forward(ctx, hiddenStates, positionIDs, m.VisionModelOptions)
|
||||||
panic(err)
|
// }
|
||||||
}
|
|
||||||
|
|
||||||
// Apply encoder layers
|
// hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
for _, layer := range m.Layers {
|
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, positionIDs, m.VisionModelOptions)
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenStates = m.PostLayerNorm.Forward(ctx, hiddenStates, m.eps)
|
|
||||||
return hiddenStates
|
return hiddenStates
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user