simplify by doing operations in Go rather than with tensors
Co-Authored-By: Michael Yang <2372640+mxyng@users.noreply.github.com>
This commit is contained in:
parent
80498f76de
commit
0f0136d419
@ -1,8 +1,8 @@
|
|||||||
package qwen25vl
|
package qwen25vl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs"
|
"github.com/ollama/ollama/fs"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
@ -167,121 +167,6 @@ func (pm *VisionPatchMerger) Forward(ctx ml.Context, x ml.Tensor, outDim, contex
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
func rope(ctx ml.Context, grid *Grid) ml.Tensor {
|
|
||||||
dim := 80 / 2 // TODO: get this from config
|
|
||||||
theta := float64(10000.0) // TODO: get this from config ropeTheta
|
|
||||||
merge := 2 // Merging factor for spatial dimensions
|
|
||||||
|
|
||||||
// Calculate inverse frequencies for rotation
|
|
||||||
inv := freqInv(ctx, dim, theta)
|
|
||||||
|
|
||||||
// Generate and stack position IDs for height and width dimensions
|
|
||||||
hPos := heightPos(ctx, grid, merge)
|
|
||||||
wPos := widthPos(ctx, grid, merge)
|
|
||||||
// Reshape both and stack them
|
|
||||||
tmp := hPos.Reshape(ctx, 1, hPos.Dim(0))
|
|
||||||
pos := tmp.Stack(ctx, 0, wPos.Reshape(ctx, 1, wPos.Dim(0)))
|
|
||||||
|
|
||||||
// Generate rotary embeddings
|
|
||||||
return rotEmbed(ctx, inv, grid.Width, pos)
|
|
||||||
}
|
|
||||||
|
|
||||||
// freqInv calculates the inverse frequencies for rotary embeddings
|
|
||||||
func freqInv(ctx ml.Context, dim int, theta float64) ml.Tensor {
|
|
||||||
logBase, err := ctx.Input().FromFloatSlice([]float32{float32(math.Log(theta))}, 1)
|
|
||||||
if err != nil {
|
|
||||||
panic(err) // TODO: handle error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create powers divided by dimension (0, 2, 4, ..., dim-2) / dim
|
|
||||||
powers := ctx.Arange(0, float32(dim), 2, ml.DTypeF32)
|
|
||||||
dims, err := ctx.Input().FromFloatSlice([]float32{float32(dim)}, 1)
|
|
||||||
if err != nil {
|
|
||||||
panic(err) // TODO: handle error
|
|
||||||
}
|
|
||||||
powers = powers.Div(ctx, dims)
|
|
||||||
|
|
||||||
// Calculate inverse frequencies: 1 / (theta ^ (powers/dim))
|
|
||||||
dims = powers.Mul(ctx, logBase).Exp(ctx)
|
|
||||||
ones, err := ctx.Input().FromFloatSlice(slices.Repeat([]float32{1.0}, dims.Shape()[0]), dims.Shape()...)
|
|
||||||
if err != nil {
|
|
||||||
panic(err) // TODO: handle error
|
|
||||||
}
|
|
||||||
return ones.Div(ctx, dims)
|
|
||||||
}
|
|
||||||
|
|
||||||
// heightPos generates position IDs for the height dimension
|
|
||||||
func heightPos(ctx ml.Context, grid *Grid, merge int) ml.Tensor {
|
|
||||||
// Create a slice where each row contains the same height value repeated width times
|
|
||||||
data := make([]float32, 0, grid.Height*grid.Width)
|
|
||||||
for i := 0; i < grid.Height; i++ {
|
|
||||||
data = append(data, slices.Repeat([]float32{float32(i)}, grid.Width)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create pos with shape [height, width]
|
|
||||||
pos, err := ctx.Input().FromFloatSlice(data, grid.Height, grid.Width)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reshape and permute for spatial merging
|
|
||||||
pos = pos.Reshape(
|
|
||||||
ctx,
|
|
||||||
merge,
|
|
||||||
grid.Width/merge,
|
|
||||||
merge,
|
|
||||||
grid.Height/merge,
|
|
||||||
)
|
|
||||||
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
// Flatten to 1D tensor
|
|
||||||
return pos.Reshape(ctx, pos.Dim(0)*pos.Dim(1)*pos.Dim(2)*pos.Dim(3))
|
|
||||||
}
|
|
||||||
|
|
||||||
// widthPos generates position IDs for the width dimension
|
|
||||||
func widthPos(ctx ml.Context, grid *Grid, merge int) ml.Tensor {
|
|
||||||
// Create a slice containing width values in column-major order
|
|
||||||
data := make([]float32, 0, grid.Height*grid.Width)
|
|
||||||
for i := 0; i < grid.Height; i++ {
|
|
||||||
for j := 0; j < grid.Width; j++ {
|
|
||||||
data = append(data, float32(j))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create pos with shape [width, height]
|
|
||||||
pos, err := ctx.Input().FromFloatSlice(data, grid.Width, grid.Height)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reshape and permute for spatial merging
|
|
||||||
pos = pos.Reshape(
|
|
||||||
ctx,
|
|
||||||
merge,
|
|
||||||
grid.Width/merge,
|
|
||||||
merge,
|
|
||||||
grid.Height/merge,
|
|
||||||
)
|
|
||||||
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
// Flatten to 1D tensor
|
|
||||||
return pos.Reshape(ctx, pos.Dim(0)*pos.Dim(1)*pos.Dim(2)*pos.Dim(3))
|
|
||||||
}
|
|
||||||
|
|
||||||
// rotEmbed generates rotary embeddings using inverse frequencies and position IDs
|
|
||||||
func rotEmbed(ctx ml.Context, freqInv ml.Tensor, maxSize int, pos ml.Tensor) ml.Tensor {
|
|
||||||
// Create sequence tensor [0, 1, 2, ..., maxGridSize-1]
|
|
||||||
seq := ctx.Arange(0, float32(maxSize), 1, ml.DTypeF32)
|
|
||||||
|
|
||||||
// Reshape for matrix multiplication and calculate outer product
|
|
||||||
outer := freqInv.Reshape(ctx, 1, freqInv.Shape()[0]).Mulmat(ctx, seq.Reshape(ctx, 1, maxSize))
|
|
||||||
|
|
||||||
// Flatten position IDs and use as indices to select rows from outer product
|
|
||||||
return outer.Rows(ctx, pos.Reshape(ctx, pos.Dim(0)*pos.Dim(1)))
|
|
||||||
|
|
||||||
// TODO: index position IDs and flatten
|
|
||||||
}
|
|
||||||
|
|
||||||
// VisionModel implements the Qwen vision model
|
// VisionModel implements the Qwen vision model
|
||||||
type VisionModel struct {
|
type VisionModel struct {
|
||||||
PatchEmbedding *PatchEmbedding
|
PatchEmbedding *PatchEmbedding
|
||||||
@ -303,11 +188,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
m.patchSize, // patch size, e.g., 14
|
m.patchSize, // patch size, e.g., 14
|
||||||
)
|
)
|
||||||
|
|
||||||
rope(ctx, grid)
|
// TODO: working here
|
||||||
|
m.rotaryEmbedding(ctx, grid)
|
||||||
// spatialMergeSize := 2 // TODO: get this from config
|
|
||||||
// // Create the position IDs tensor with correct dimensions
|
|
||||||
// positions := []int32{}
|
|
||||||
|
|
||||||
// // Apply encoder layers
|
// // Apply encoder layers
|
||||||
// for _, layer := range m.Layers {
|
// for _, layer := range m.Layers {
|
||||||
@ -318,6 +200,62 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
return hiddenStates
|
return hiddenStates
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rotaryEmbedding generates rotary position embeddings for attention mechanisms
|
||||||
|
// This implements rotary embeddings using spatial merging patterns for grid-based
|
||||||
|
// vision transformers
|
||||||
|
func (m *VisionModel) rotaryEmbedding(ctx ml.Context, grid *Grid) (ml.Tensor, ml.Tensor) {
|
||||||
|
// Configuration parameters
|
||||||
|
dim := 80 / 2 // Head dimension divided by 2
|
||||||
|
freq := dim / 2 // Frequency dimension (half of head dimension)
|
||||||
|
theta := 10000.0 // Base for frequency scaling
|
||||||
|
merge := 2 // Spatial merge size for rearranging coordinates
|
||||||
|
|
||||||
|
// Create frequency patterns for position encoding
|
||||||
|
// These are scaled position values based on frequency
|
||||||
|
// In PyTorch: Similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
|
||||||
|
freqVals := make([]float32, freq*grid.Width)
|
||||||
|
for i := range grid.Width {
|
||||||
|
for j := range freq {
|
||||||
|
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, grid.Width)
|
||||||
|
if err != nil {
|
||||||
|
panic(err) // TODO: handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create position coordinates (y,x pairs) for the grid
|
||||||
|
// In PyTorch: Equivalent to generating position ids with torch.arange()
|
||||||
|
coords := make([]int32, 0, grid.Height*grid.Width*2)
|
||||||
|
for y := range grid.Height {
|
||||||
|
for x := range grid.Width {
|
||||||
|
coords = append(coords, int32(y), int32(x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
|
||||||
|
if err != nil {
|
||||||
|
panic(err) // TODO: handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reshape and permute positions to match spatial merging pattern
|
||||||
|
// This rearranges positions to group spatially related coordinates
|
||||||
|
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
|
||||||
|
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
|
||||||
|
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
|
||||||
|
|
||||||
|
// Use position indices to look up corresponding frequency values
|
||||||
|
out := freqs.Rows(ctx, pos)
|
||||||
|
out = out.Reshape(ctx, out.Dim(0)*2, out.Dim(1)/2)
|
||||||
|
|
||||||
|
fmt.Println("out", out.Shape())
|
||||||
|
fmt.Println(ml.Dump(ctx, out))
|
||||||
|
|
||||||
|
// TODO: return cos and sin tensors for rotary embedding
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// newVisionModel creates a new instance of the Qwen vision model
|
// newVisionModel creates a new instance of the Qwen vision model
|
||||||
func newVisionModel(c fs.Config) *VisionModel {
|
func newVisionModel(c fs.Config) *VisionModel {
|
||||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user