grid refactor

This commit is contained in:
Bruce MacDonald 2025-04-21 15:06:13 -07:00
parent 963531215e
commit eedc969c35
3 changed files with 26 additions and 29 deletions

View File

@ -64,9 +64,7 @@ func New(c fs.Config) (model.Model, error) {
type imageFeatures struct { type imageFeatures struct {
Tensor ml.Tensor Tensor ml.Tensor
GridT int Grid *Grid
GridH int
GridW int
} }
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
@ -79,7 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return nil, err return nil, err
} }
f32s, gridT, gridH, gridW, err := m.ImageProcessor.ProcessImage(image) f32s, grid, err := m.ImageProcessor.ProcessImage(image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,7 +85,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
// Calculate tensor dimensions // Calculate tensor dimensions
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize * patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
numPatches := gridT * gridH * gridW numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
if err != nil { if err != nil {
@ -99,9 +97,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return &imageFeatures{ return &imageFeatures{
Tensor: visionOutputs, Tensor: visionOutputs,
GridT: gridT, Grid: grid,
GridH: gridH,
GridW: gridW,
}, nil }, nil
} }
@ -125,14 +121,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
// This is an image token with multimodal data // This is an image token with multimodal data
features := inp.Multimodal.(*imageFeatures) features := inp.Multimodal.(*imageFeatures)
// Get grid dimensions from the features
gridT := features.GridT
gridH := features.GridH
gridW := features.GridW
// Calculate tokens per grid based on grid dimensions // Calculate tokens per grid based on grid dimensions
mergeLength := mergeSize * mergeSize mergeLength := mergeSize * mergeSize
gridProduct := gridT * gridH * gridW gridProduct := features.Grid.Temporal * features.Grid.Height * features.Grid.Width
tokensPerGrid := gridProduct / mergeLength tokensPerGrid := gridProduct / mergeLength
// First add the vision start token // First add the vision start token

View File

@ -177,7 +177,7 @@ type VisionModel struct {
} }
// Forward computes the vision model for an input tensor // Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Calculate position IDs for 2D RoPE // Calculate position IDs for 2D RoPE
numPatchesH := pixelValues.Dim(0) / m.patchSize numPatchesH := pixelValues.Dim(0) / m.patchSize
numPatchesW := pixelValues.Dim(1) / m.patchSize numPatchesW := pixelValues.Dim(1) / m.patchSize
@ -193,14 +193,12 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
) )
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position // Create position IDs - for Qwen2VL mRoPE we need 4 values per position
// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"
positions := make([]int32, numPatches*4) positions := make([]int32, numPatches*4)
for h := 0; h < numPatchesH; h++ { for h := 0; h < numPatchesH; h++ {
for w := 0; w < numPatchesW; w++ { for w := 0; w < numPatchesW; w++ {
idx := h*numPatchesW + w idx := h*numPatchesW + w
// For each position, store both h and w coordinates twice // For each position, store both h and w coordinates twice
// This matches the pattern seen in the C++ implementation
positions[idx*4] = int32(h) // y coordinate positions[idx*4] = int32(h) // y coordinate
positions[idx*4+1] = int32(w) // x coordinate positions[idx*4+1] = int32(w) // x coordinate
positions[idx*4+2] = int32(h) // y coordinate (repeated) positions[idx*4+2] = int32(h) // y coordinate (repeated)

View File

@ -77,7 +77,13 @@ func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
return hBar, wBar return hBar, wBar
} }
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int, error) { type Grid struct {
Height int
Width int
Temporal int
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
origWidth := img.Bounds().Dx() origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy() origHeight := img.Bounds().Dy()
@ -96,27 +102,29 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int
) )
// Calculate grid dimensions // Calculate grid dimensions
gridH := resizedHeight / p.patchSize grid := &Grid{
gridW := resizedWidth / p.patchSize Height: resizedHeight / p.patchSize,
gridT := 1 // For single images, temporal dimension is 1 Width: resizedWidth / p.patchSize,
Temporal: 1, // For single images, temporal dimension is 1
}
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, gridH, gridW, gridT) patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil { if err != nil {
return nil, 0, 0, 0, fmt.Errorf("failed to create patches: %v", err) return nil, nil, fmt.Errorf("failed to create patches: %v", err)
} }
// Return patches and grid dimensions // Return patches and grid dimensions
return patches, gridT, gridH, gridW, nil return patches, grid, nil
} }
func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, gridW, gridT int) ([]float32, error) { func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
channels := p.numChannels channels := p.numChannels
patchSize := p.patchSize patchSize := p.patchSize
mergeSize := p.mergeSize mergeSize := p.mergeSize
temporalPatchSize := p.temporalPatchSize temporalPatchSize := p.temporalPatchSize
// Calculate output dimensions // Calculate output dimensions
numPatches := gridT * gridH * gridW numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize patchDim := channels * temporalPatchSize * patchSize * patchSize
// Create output tensor // Create output tensor
@ -126,10 +134,10 @@ func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, g
// in the format expected by the forward pass // in the format expected by the forward pass
patchIndex := 0 patchIndex := 0
for t := 0; t < gridT; t++ { for t := 0; t < grid.Temporal; t++ {
// For each patch in the grid // For each patch in the grid
for h := 0; h < gridH; h += mergeSize { for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < gridW; w += mergeSize { for w := 0; w < grid.Width; w += mergeSize {
// Handle the 2x2 merged patches // Handle the 2x2 merged patches
for mh := 0; mh < mergeSize; mh++ { for mh := 0; mh < mergeSize; mh++ {
for mw := 0; mw < mergeSize; mw++ { for mw := 0; mw < mergeSize; mw++ {