grid refactor
This commit is contained in:
parent
963531215e
commit
eedc969c35
@ -64,9 +64,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
|
|
||||||
type imageFeatures struct {
|
type imageFeatures struct {
|
||||||
Tensor ml.Tensor
|
Tensor ml.Tensor
|
||||||
GridT int
|
Grid *Grid
|
||||||
GridH int
|
|
||||||
GridW int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||||
@ -79,7 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f32s, gridT, gridH, gridW, err := m.ImageProcessor.ProcessImage(image)
|
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -87,7 +85,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
|||||||
// Calculate tensor dimensions
|
// Calculate tensor dimensions
|
||||||
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
||||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
||||||
numPatches := gridT * gridH * gridW
|
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||||
|
|
||||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -99,9 +97,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
|||||||
|
|
||||||
return &imageFeatures{
|
return &imageFeatures{
|
||||||
Tensor: visionOutputs,
|
Tensor: visionOutputs,
|
||||||
GridT: gridT,
|
Grid: grid,
|
||||||
GridH: gridH,
|
|
||||||
GridW: gridW,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,14 +121,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
// This is an image token with multimodal data
|
// This is an image token with multimodal data
|
||||||
features := inp.Multimodal.(*imageFeatures)
|
features := inp.Multimodal.(*imageFeatures)
|
||||||
|
|
||||||
// Get grid dimensions from the features
|
|
||||||
gridT := features.GridT
|
|
||||||
gridH := features.GridH
|
|
||||||
gridW := features.GridW
|
|
||||||
|
|
||||||
// Calculate tokens per grid based on grid dimensions
|
// Calculate tokens per grid based on grid dimensions
|
||||||
mergeLength := mergeSize * mergeSize
|
mergeLength := mergeSize * mergeSize
|
||||||
gridProduct := gridT * gridH * gridW
|
gridProduct := features.Grid.Temporal * features.Grid.Height * features.Grid.Width
|
||||||
tokensPerGrid := gridProduct / mergeLength
|
tokensPerGrid := gridProduct / mergeLength
|
||||||
|
|
||||||
// First add the vision start token
|
// First add the vision start token
|
||||||
|
@ -177,7 +177,7 @@ type VisionModel struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Forward computes the vision model for an input tensor
|
// Forward computes the vision model for an input tensor
|
||||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
|
||||||
// Calculate position IDs for 2D RoPE
|
// Calculate position IDs for 2D RoPE
|
||||||
numPatchesH := pixelValues.Dim(0) / m.patchSize
|
numPatchesH := pixelValues.Dim(0) / m.patchSize
|
||||||
numPatchesW := pixelValues.Dim(1) / m.patchSize
|
numPatchesW := pixelValues.Dim(1) / m.patchSize
|
||||||
@ -193,14 +193,12 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
|
// Create position IDs - for Qwen2VL mRoPE we need 4 values per position
|
||||||
// The format needed is specified in the C++ code as "mrope expecting 4 position ids per token"
|
|
||||||
positions := make([]int32, numPatches*4)
|
positions := make([]int32, numPatches*4)
|
||||||
|
|
||||||
for h := 0; h < numPatchesH; h++ {
|
for h := 0; h < numPatchesH; h++ {
|
||||||
for w := 0; w < numPatchesW; w++ {
|
for w := 0; w < numPatchesW; w++ {
|
||||||
idx := h*numPatchesW + w
|
idx := h*numPatchesW + w
|
||||||
// For each position, store both h and w coordinates twice
|
// For each position, store both h and w coordinates twice
|
||||||
// This matches the pattern seen in the C++ implementation
|
|
||||||
positions[idx*4] = int32(h) // y coordinate
|
positions[idx*4] = int32(h) // y coordinate
|
||||||
positions[idx*4+1] = int32(w) // x coordinate
|
positions[idx*4+1] = int32(w) // x coordinate
|
||||||
positions[idx*4+2] = int32(h) // y coordinate (repeated)
|
positions[idx*4+2] = int32(h) // y coordinate (repeated)
|
||||||
|
@ -77,7 +77,13 @@ func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
|||||||
return hBar, wBar
|
return hBar, wBar
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int, error) {
|
type Grid struct {
|
||||||
|
Height int
|
||||||
|
Width int
|
||||||
|
Temporal int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
|
||||||
origWidth := img.Bounds().Dx()
|
origWidth := img.Bounds().Dx()
|
||||||
origHeight := img.Bounds().Dy()
|
origHeight := img.Bounds().Dy()
|
||||||
|
|
||||||
@ -96,27 +102,29 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, int
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Calculate grid dimensions
|
// Calculate grid dimensions
|
||||||
gridH := resizedHeight / p.patchSize
|
grid := &Grid{
|
||||||
gridW := resizedWidth / p.patchSize
|
Height: resizedHeight / p.patchSize,
|
||||||
gridT := 1 // For single images, temporal dimension is 1
|
Width: resizedWidth / p.patchSize,
|
||||||
|
Temporal: 1, // For single images, temporal dimension is 1
|
||||||
|
}
|
||||||
|
|
||||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, gridH, gridW, gridT)
|
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, 0, 0, fmt.Errorf("failed to create patches: %v", err)
|
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return patches and grid dimensions
|
// Return patches and grid dimensions
|
||||||
return patches, gridT, gridH, gridW, nil
|
return patches, grid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, gridW, gridT int) ([]float32, error) {
|
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
|
||||||
channels := p.numChannels
|
channels := p.numChannels
|
||||||
patchSize := p.patchSize
|
patchSize := p.patchSize
|
||||||
mergeSize := p.mergeSize
|
mergeSize := p.mergeSize
|
||||||
temporalPatchSize := p.temporalPatchSize
|
temporalPatchSize := p.temporalPatchSize
|
||||||
|
|
||||||
// Calculate output dimensions
|
// Calculate output dimensions
|
||||||
numPatches := gridT * gridH * gridW
|
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||||
|
|
||||||
// Create output tensor
|
// Create output tensor
|
||||||
@ -126,10 +134,10 @@ func (p *ImageProcessor) createPatches(pixels []float32, height, width, gridH, g
|
|||||||
// in the format expected by the forward pass
|
// in the format expected by the forward pass
|
||||||
patchIndex := 0
|
patchIndex := 0
|
||||||
|
|
||||||
for t := 0; t < gridT; t++ {
|
for t := 0; t < grid.Temporal; t++ {
|
||||||
// For each patch in the grid
|
// For each patch in the grid
|
||||||
for h := 0; h < gridH; h += mergeSize {
|
for h := 0; h < grid.Height; h += mergeSize {
|
||||||
for w := 0; w < gridW; w += mergeSize {
|
for w := 0; w < grid.Width; w += mergeSize {
|
||||||
// Handle the 2x2 merged patches
|
// Handle the 2x2 merged patches
|
||||||
for mh := 0; mh < mergeSize; mh++ {
|
for mh := 0; mh < mergeSize; mh++ {
|
||||||
for mw := 0; mw < mergeSize; mw++ {
|
for mw := 0; mw < mergeSize; mw++ {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user