window index

This commit is contained in:
Michael Yang 2025-04-29 14:46:23 -07:00 committed by Bruce MacDonald
parent ee869f35e4
commit fb3c16f2a2
2 changed files with 61 additions and 3 deletions

View File

@ -18,8 +18,9 @@ type qwen25VLModel struct {
InChannels uint32 `json:"in_chans"`
NumHeads uint32 `json:"num_heads"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"` // TODO: is this set?
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RopeTheta float32 `json:"rope_theta"`
} `json:"vision_config"`
}

View File

@ -100,6 +100,9 @@ type VisionModelOptions struct {
eps float32
ropeTheta float32
outHiddenSize int
spatialMergeSize int
spatialPatchSize int
windowSize int
}
type PatchEmbedding struct {
@ -184,6 +187,20 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
)
positionEmbedding := m.positionalEmbedding(ctx, grid)
windowIndex := m.windowIndex(ctx, grid)
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
hiddenStates = hiddenStates.Rows(ctx, windowIndex)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
@ -196,6 +213,45 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
return m.PatchMerger.Forward(ctx, hiddenStates, m.eps)
}
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
llmGridH := grid.Height / m.spatialMergeSize
llmGridW := grid.Width / m.spatialMergeSize
// Calculate window parameters
numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
// Initialize index_new slice
var index []int32
// Process each window without padding
for wh := range numWindowsH {
for ww := range numWindowsW {
// Calculate window boundaries
hStart := wh * vitMergerWindowSize
wStart := ww * vitMergerWindowSize
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
// Collect indices for this window
for h := hStart; h < hEnd; h++ {
for w := wStart; w < wEnd; w++ {
index = append(index, int32(h*llmGridW+w))
}
}
}
}
t, err := ctx.Input().FromIntSlice(index, len(index))
if err != nil {
panic(err)
}
return t
}
// positionalEmbedding generates rotary position embeddings for attention mechanisms
// This implements rotary embeddings using spatial merging patterns for grid-based
// vision transformers
@ -245,8 +301,6 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
// Use position indices to look up corresponding frequency values
positionalEmbedding := freqs.Rows(ctx, pos)
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
positionalEmbedding = positionalEmbedding.Concat(ctx, positionalEmbedding, 0)
return positionalEmbedding
}
@ -271,6 +325,9 @@ func newVisionModel(c fs.Config) *VisionModel {
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
ropeTheta: ropeTheta,
outHiddenSize: outHiddenSize,
spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
spatialPatchSize: int(c.Uint("vision.spatial_patch_size", 2)),
windowSize: int(c.Uint("vision.window_size", 112)),
},
}
}