grid refactor

This commit is contained in:
Bruce MacDonald 2025-04-21 15:06:13 -07:00
parent 963531215e
commit eedc969c35
3 changed files with 26 additions and 29 deletions

View File

@ -64,9 +64,7 @@ func New(c fs.Config) (model.Model, error) {
type imageFeatures struct {
Tensor ml.Tensor
GridT int
GridH int
GridW int
Grid *Grid
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
@ -79,7 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return nil, err
}
f32s, gridT, gridH, gridW, err := m.ImageProcessor.ProcessImage(image)
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
@ -87,7 +85,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
// Calculate tensor dimensions
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
numPatches := gridT * gridH * gridW
numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
if err != nil {
@ -99,9 +97,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return &imageFeatures{
Tensor: visionOutputs,
GridT: gridT,
GridH: gridH,
GridW: gridW,
Grid: grid,
}, nil
}
@ -125,14 +121,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
// This is an image token with multimodal data
features := inp.Multimodal.(*imageFeatures)
// Get grid dimensions from the features
gridT := features.GridT
gridH := features.GridH
gridW := features.GridW
// Calculate tokens per grid based on grid dimensions
mergeLength := mergeSize * mergeSize
gridProduct := gridT * gridH * gridW
gridProduct := features.Grid.Temporal * features.Grid.Height * features.Grid.Width
tokensPerGrid := gridProduct / mergeLength
// First add the vision start token

View File

@ -177,7 +177,7 @@ type VisionModel struct {
}
// Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Calculate position IDs for 2D RoPE
numPatchesH := pixelValues.Dim(0) / m.patchSize
numPatchesW := pixelValues.Dim(1) / m.patchSize
@ -193,14 +193,12 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
)
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"
positions := make([]int32, numPatches*4)
for h := 0; h < numPatchesH; h++ {
for w := 0; w < numPatchesW; w++ {
idx := h*numPatchesW + w
// For each position, store both h and w coordinates twice
// This matches the pattern seen in the C++ implementation
positions[idx*4] = int32(h) // y coordinate
positions[idx*4+1] = int32(w) // x coordinate
positions[idx*4+2] = int32(h) // y coordinate (repeated)

View File

@ -77,7 +77,13 @@ func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
return hBar, wBar
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int, error) {
type Grid struct {
Height int
Width int
Temporal int
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
@ -96,27 +102,29 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int
)
// Calculate grid dimensions
gridH := resizedHeight / p.patchSize
gridW := resizedWidth / p.patchSize
gridT := 1 // For single images, temporal dimension is 1
grid := &Grid{
Height: resizedHeight / p.patchSize,
Width: resizedWidth / p.patchSize,
Temporal: 1, // For single images, temporal dimension is 1
}
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, gridH, gridW, gridT)
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, 0, 0, 0, fmt.Errorf("failed to create patches: %v", err)
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
}
// Return patches and grid dimensions
return patches, gridT, gridH, gridW, nil
return patches, grid, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, gridW, gridT int) ([]float32, error) {
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
channels := p.numChannels
patchSize := p.patchSize
mergeSize := p.mergeSize
temporalPatchSize := p.temporalPatchSize
// Calculate output dimensions
numPatches := gridT * gridH * gridW
numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize
// Create output tensor
@ -126,10 +134,10 @@ func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, g
// in the format expected by the forward pass
patchIndex := 0
for t := 0; t < gridT; t++ {
for t := 0; t < grid.Temporal; t++ {
// For each patch in the grid
for h := 0; h < gridH; h += mergeSize {
for w := 0; w < gridW; w += mergeSize {
for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < grid.Width; w += mergeSize {
// Handle the 2x2 merged patches
for mh := 0; mh < mergeSize; mh++ {
for mw := 0; mw < mergeSize; mw++ {