window index
This commit is contained in:
parent
ee869f35e4
commit
fb3c16f2a2
@ -18,8 +18,9 @@ type qwen25VLModel struct {
|
||||
InChannels uint32 `json:"in_chans"`
|
||||
NumHeads uint32 `json:"num_heads"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"` // TODO: is this set?
|
||||
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||
SpatialPatchSize uint32 `json:"spatial_patch_size"`
|
||||
WindowSize uint32 `json:"window_size"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
@ -100,6 +100,9 @@ type VisionModelOptions struct {
|
||||
eps float32
|
||||
ropeTheta float32
|
||||
outHiddenSize int
|
||||
spatialMergeSize int
|
||||
spatialPatchSize int
|
||||
windowSize int
|
||||
}
|
||||
|
||||
type PatchEmbedding struct {
|
||||
@ -184,6 +187,20 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
)
|
||||
|
||||
positionEmbedding := m.positionalEmbedding(ctx, grid)
|
||||
|
||||
windowIndex := m.windowIndex(ctx, grid)
|
||||
|
||||
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
||||
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
|
||||
hiddenStates = hiddenStates.Rows(ctx, windowIndex)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
|
||||
|
||||
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
|
||||
positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
|
||||
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
|
||||
positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
|
||||
|
||||
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||
@ -196,6 +213,45 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
return m.PatchMerger.Forward(ctx, hiddenStates, m.eps)
|
||||
}
|
||||
|
||||
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
|
||||
|
||||
llmGridH := grid.Height / m.spatialMergeSize
|
||||
llmGridW := grid.Width / m.spatialMergeSize
|
||||
|
||||
// Calculate window parameters
|
||||
numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
|
||||
numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
|
||||
|
||||
// Initialize index_new slice
|
||||
var index []int32
|
||||
|
||||
// Process each window without padding
|
||||
for wh := range numWindowsH {
|
||||
for ww := range numWindowsW {
|
||||
// Calculate window boundaries
|
||||
hStart := wh * vitMergerWindowSize
|
||||
wStart := ww * vitMergerWindowSize
|
||||
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
|
||||
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
|
||||
|
||||
// Collect indices for this window
|
||||
for h := hStart; h < hEnd; h++ {
|
||||
for w := wStart; w < wEnd; w++ {
|
||||
index = append(index, int32(h*llmGridW+w))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t, err := ctx.Input().FromIntSlice(index, len(index))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// positionalEmbedding generates rotary position embeddings for attention mechanisms
|
||||
// This implements rotary embeddings using spatial merging patterns for grid-based
|
||||
// vision transformers
|
||||
@ -245,8 +301,6 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
|
||||
// Use position indices to look up corresponding frequency values
|
||||
positionalEmbedding := freqs.Rows(ctx, pos)
|
||||
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
|
||||
positionalEmbedding = positionalEmbedding.Concat(ctx, positionalEmbedding, 0)
|
||||
|
||||
return positionalEmbedding
|
||||
}
|
||||
|
||||
@ -271,6 +325,9 @@ func newVisionModel(c fs.Config) *VisionModel {
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||
ropeTheta: ropeTheta,
|
||||
outHiddenSize: outHiddenSize,
|
||||
spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
|
||||
spatialPatchSize: int(c.Uint("vision.spatial_patch_size", 2)),
|
||||
windowSize: int(c.Uint("vision.window_size", 112)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user