grid refactor
This commit is contained in:
parent
963531215e
commit
eedc969c35
@ -64,9 +64,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
type imageFeatures struct {
|
||||
Tensor ml.Tensor
|
||||
GridT int
|
||||
GridH int
|
||||
GridW int
|
||||
Grid *Grid
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||
@ -79,7 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, gridT, gridH, gridW, err := m.ImageProcessor.ProcessImage(image)
|
||||
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -87,7 +85,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
// Calculate tensor dimensions
|
||||
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
||||
numPatches := gridT * gridH * gridW
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||
if err != nil {
|
||||
@ -99,9 +97,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
|
||||
return &imageFeatures{
|
||||
Tensor: visionOutputs,
|
||||
GridT: gridT,
|
||||
GridH: gridH,
|
||||
GridW: gridW,
|
||||
Grid: grid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -125,14 +121,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
// This is an image token with multimodal data
|
||||
features := inp.Multimodal.(*imageFeatures)
|
||||
|
||||
// Get grid dimensions from the features
|
||||
gridT := features.GridT
|
||||
gridH := features.GridH
|
||||
gridW := features.GridW
|
||||
|
||||
// Calculate tokens per grid based on grid dimensions
|
||||
mergeLength := mergeSize * mergeSize
|
||||
gridProduct := gridT * gridH * gridW
|
||||
gridProduct := features.Grid.Temporal * features.Grid.Height * features.Grid.Width
|
||||
tokensPerGrid := gridProduct / mergeLength
|
||||
|
||||
// First add the vision start token
|
||||
|
@ -177,7 +177,7 @@ type VisionModel struct {
|
||||
}
|
||||
|
||||
// Forward computes the vision model for an input tensor
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
|
||||
// Calculate position IDs for 2D RoPE
|
||||
numPatchesH := pixelValues.Dim(0) / m.patchSize
|
||||
numPatchesW := pixelValues.Dim(1) / m.patchSize
|
||||
@ -193,14 +193,12 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
)
|
||||
|
||||
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
|
||||
// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"
|
||||
positions := make([]int32, numPatches*4)
|
||||
|
||||
for h := 0; h < numPatchesH; h++ {
|
||||
for w := 0; w < numPatchesW; w++ {
|
||||
idx := h*numPatchesW + w
|
||||
// For each position, store both h and w coordinates twice
|
||||
// This matches the pattern seen in the C++ implementation
|
||||
positions[idx*4] = int32(h) // y coordinate
|
||||
positions[idx*4+1] = int32(w) // x coordinate
|
||||
positions[idx*4+2] = int32(h) // y coordinate (repeated)
|
||||
|
@ -77,7 +77,13 @@ func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int, error) {
|
||||
type Grid struct {
|
||||
Height int
|
||||
Width int
|
||||
Temporal int
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
|
||||
origWidth := img.Bounds().Dx()
|
||||
origHeight := img.Bounds().Dy()
|
||||
|
||||
@ -96,27 +102,29 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int
|
||||
)
|
||||
|
||||
// Calculate grid dimensions
|
||||
gridH := resizedHeight / p.patchSize
|
||||
gridW := resizedWidth / p.patchSize
|
||||
gridT := 1 // For single images, temporal dimension is 1
|
||||
grid := &Grid{
|
||||
Height: resizedHeight / p.patchSize,
|
||||
Width: resizedWidth / p.patchSize,
|
||||
Temporal: 1, // For single images, temporal dimension is 1
|
||||
}
|
||||
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, gridH, gridW, gridT)
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
|
||||
if err != nil {
|
||||
return nil, 0, 0, 0, fmt.Errorf("failed to create patches: %v", err)
|
||||
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
|
||||
}
|
||||
|
||||
// Return patches and grid dimensions
|
||||
return patches, gridT, gridH, gridW, nil
|
||||
return patches, grid, nil
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, gridW, gridT int) ([]float32, error) {
|
||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
|
||||
channels := p.numChannels
|
||||
patchSize := p.patchSize
|
||||
mergeSize := p.mergeSize
|
||||
temporalPatchSize := p.temporalPatchSize
|
||||
|
||||
// Calculate output dimensions
|
||||
numPatches := gridT * gridH * gridW
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||
|
||||
// Create output tensor
|
||||
@ -126,10 +134,10 @@ func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, g
|
||||
// in the format expected by the forward pass
|
||||
patchIndex := 0
|
||||
|
||||
for t := 0; t < gridT; t++ {
|
||||
for t := 0; t < grid.Temporal; t++ {
|
||||
// For each patch in the grid
|
||||
for h := 0; h < gridH; h += mergeSize {
|
||||
for w := 0; w < gridW; w += mergeSize {
|
||||
for h := 0; h < grid.Height; h += mergeSize {
|
||||
for w := 0; w < grid.Width; w += mergeSize {
|
||||
// Handle the 2x2 merged patches
|
||||
for mh := 0; mh < mergeSize; mh++ {
|
||||
for mw := 0; mw < mergeSize; mw++ {
|
||||
|
Loading…
x
Reference in New Issue
Block a user