diff --git a/convert/convert_qwen25vl.go b/convert/convert_qwen25vl.go
index 4d8d248ca..2008f2d16 100644
--- a/convert/convert_qwen25vl.go
+++ b/convert/convert_qwen25vl.go
@@ -18,8 +18,9 @@ type qwen25VLModel struct {
 		InChannels       uint32  `json:"in_chans"`
 		NumHeads         uint32  `json:"num_heads"`
 		PatchSize        uint32  `json:"patch_size"`
-		SpatialMergeSize uint32  `json:"spatial_merge_size"` // TODO: is this set?
+		SpatialMergeSize uint32  `json:"spatial_merge_size"`
 		SpatialPatchSize uint32  `json:"spatial_patch_size"`
+		WindowSize       uint32  `json:"window_size"`
 		RopeTheta        float32 `json:"rope_theta"`
 	} `json:"vision_config"`
 }
diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index 65980d166..f25d3e1aa 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -100,6 +100,9 @@ type VisionModelOptions struct {
 	eps              float32
 	ropeTheta        float32
 	outHiddenSize    int
+	spatialMergeSize int
+	spatialPatchSize int
+	windowSize       int
 }
 
 type PatchEmbedding struct {
@@ -184,6 +187,20 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 	)
 
 	positionEmbedding := m.positionalEmbedding(ctx, grid)
+
+	windowIndex := m.windowIndex(ctx, grid)
+
+	spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
+
+	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
+	hiddenStates = hiddenStates.Rows(ctx, windowIndex)
+	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
+
+	positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
+	positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
+	positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
+	positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
+
 	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
 	cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
 	sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
@@ -196,6 +213,45 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 	return m.PatchMerger.Forward(ctx, hiddenStates, m.eps)
 }
 
+func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
+	vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
+
+	llmGridH := grid.Height / m.spatialMergeSize
+	llmGridW := grid.Width / m.spatialMergeSize
+
+	// Calculate window parameters
+	numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
+	numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
+
+	// Initialize the index slice
+	var index []int32
+
+	// Process each window without padding
+	for wh := range numWindowsH {
+		for ww := range numWindowsW {
+			// Calculate window boundaries
+			hStart := wh * vitMergerWindowSize
+			wStart := ww * vitMergerWindowSize
+			hEnd := min(hStart+vitMergerWindowSize, llmGridH)
+			wEnd := min(wStart+vitMergerWindowSize, llmGridW)
+
+			// Collect indices for this window
+			for h := hStart; h < hEnd; h++ {
+				for w := wStart; w < wEnd; w++ {
+					index = append(index, int32(h*llmGridW+w))
+				}
+			}
+		}
+	}
+
+	t, err := ctx.Input().FromIntSlice(index, len(index))
+	if err != nil {
+		panic(err)
+	}
+
+	return t
+}
+
 // positionalEmbedding generates rotary position embeddings for attention mechanisms
 // This implements rotary embeddings using spatial merging patterns for grid-based
 // vision transformers
@@ -245,8 +301,6 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
 	// Use position indices to look up corresponding frequency values
 	positionalEmbedding := freqs.Rows(ctx, pos)
 	positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
-	positionalEmbedding = positionalEmbedding.Concat(ctx, positionalEmbedding, 0)
-
 	return positionalEmbedding
 }
 
@@ -271,6 +325,9 @@ func newVisionModel(c fs.Config) *VisionModel {
 			eps:              c.Float("vision.attention.layer_norm_epsilon", 1e-6),
 			ropeTheta:        ropeTheta,
 			outHiddenSize:    outHiddenSize,
+			spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
+			spatialPatchSize: int(c.Uint("vision.spatial_patch_size", 2)),
+			windowSize:       int(c.Uint("vision.window_size", 112)),
 		},
 	}
 }
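
Note (illustrative, not part of the diff): the sketch below shows the window-major token ordering that the new windowIndex method computes, using plain Go slices in place of ml.Tensor so it can be run on its own. The grid dimensions and window size here are made-up example values; in Forward, Rows(ctx, windowIndex) applies this permutation to hiddenStates and positionEmbedding, after grouping spatialMergeUnit = spatialMergeSize * spatialMergeSize rows together so each merged token moves as one unit.

// window_order_sketch.go - standalone illustration of the window index ordering.
package main

import "fmt"

// windowOrder mirrors the loop structure of (*VisionModel).windowIndex: it walks
// an llmGridH x llmGridW grid of merged tokens window by window and records the
// flat (row-major) index of every token, so each window's tokens become contiguous.
func windowOrder(llmGridH, llmGridW, window int) []int32 {
	var index []int32
	for hStart := 0; hStart < llmGridH; hStart += window {
		for wStart := 0; wStart < llmGridW; wStart += window {
			// Clamp the window to the grid edge (no padding).
			hEnd := min(hStart+window, llmGridH)
			wEnd := min(wStart+window, llmGridW)
			for h := hStart; h < hEnd; h++ {
				for w := wStart; w < wEnd; w++ {
					index = append(index, int32(h*llmGridW+w))
				}
			}
		}
	}
	return index
}

func main() {
	// Example values only: a 4x4 merged grid with 2x2 windows.
	fmt.Println(windowOrder(4, 4, 2))
	// Prints [0 1 4 5 2 3 6 7 8 9 12 13 10 11 14 15]:
	// the four tokens of each 2x2 window are now adjacent, which is the
	// ordering window attention expects.
}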