window index
This commit is contained in:
parent
ee869f35e4
commit
fb3c16f2a2
@ -18,8 +18,9 @@ type qwen25VLModel struct {
|
|||||||
InChannels uint32 `json:"in_chans"`
|
InChannels uint32 `json:"in_chans"`
|
||||||
NumHeads uint32 `json:"num_heads"`
|
NumHeads uint32 `json:"num_heads"`
|
||||||
PatchSize uint32 `json:"patch_size"`
|
PatchSize uint32 `json:"patch_size"`
|
||||||
SpatialMergeSize uint32 `json:"spatial_merge_size"` // TODO: is this set?
|
SpatialMergeSize uint32 `json:"spatial_merge_size"`
|
||||||
SpatialPatchSize uint32 `json:"spatial_patch_size"`
|
SpatialPatchSize uint32 `json:"spatial_patch_size"`
|
||||||
|
WindowSize uint32 `json:"window_size"`
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
} `json:"vision_config"`
|
} `json:"vision_config"`
|
||||||
}
|
}
|
||||||
|
@ -100,6 +100,9 @@ type VisionModelOptions struct {
|
|||||||
eps float32
|
eps float32
|
||||||
ropeTheta float32
|
ropeTheta float32
|
||||||
outHiddenSize int
|
outHiddenSize int
|
||||||
|
spatialMergeSize int
|
||||||
|
spatialPatchSize int
|
||||||
|
windowSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
type PatchEmbedding struct {
|
type PatchEmbedding struct {
|
||||||
@ -184,6 +187,20 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
)
|
)
|
||||||
|
|
||||||
positionEmbedding := m.positionalEmbedding(ctx, grid)
|
positionEmbedding := m.positionalEmbedding(ctx, grid)
|
||||||
|
|
||||||
|
windowIndex := m.windowIndex(ctx, grid)
|
||||||
|
|
||||||
|
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
||||||
|
|
||||||
|
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
|
||||||
|
hiddenStates = hiddenStates.Rows(ctx, windowIndex)
|
||||||
|
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
|
||||||
|
|
||||||
|
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
|
||||||
|
positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
|
||||||
|
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
|
||||||
|
positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
|
||||||
|
|
||||||
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
||||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||||
@ -196,6 +213,45 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
return m.PatchMerger.Forward(ctx, hiddenStates, m.eps)
|
return m.PatchMerger.Forward(ctx, hiddenStates, m.eps)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||||
|
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
|
||||||
|
|
||||||
|
llmGridH := grid.Height / m.spatialMergeSize
|
||||||
|
llmGridW := grid.Width / m.spatialMergeSize
|
||||||
|
|
||||||
|
// Calculate window parameters
|
||||||
|
numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
|
||||||
|
numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
|
||||||
|
|
||||||
|
// Initialize index_new slice
|
||||||
|
var index []int32
|
||||||
|
|
||||||
|
// Process each window without padding
|
||||||
|
for wh := range numWindowsH {
|
||||||
|
for ww := range numWindowsW {
|
||||||
|
// Calculate window boundaries
|
||||||
|
hStart := wh * vitMergerWindowSize
|
||||||
|
wStart := ww * vitMergerWindowSize
|
||||||
|
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
|
||||||
|
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
|
||||||
|
|
||||||
|
// Collect indices for this window
|
||||||
|
for h := hStart; h < hEnd; h++ {
|
||||||
|
for w := wStart; w < wEnd; w++ {
|
||||||
|
index = append(index, int32(h*llmGridW+w))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := ctx.Input().FromIntSlice(index, len(index))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
// positionalEmbedding generates rotary position embeddings for attention mechanisms
|
// positionalEmbedding generates rotary position embeddings for attention mechanisms
|
||||||
// This implements rotary embeddings using spatial merging patterns for grid-based
|
// This implements rotary embeddings using spatial merging patterns for grid-based
|
||||||
// vision transformers
|
// vision transformers
|
||||||
@ -245,8 +301,6 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
|
|||||||
// Use position indices to look up corresponding frequency values
|
// Use position indices to look up corresponding frequency values
|
||||||
positionalEmbedding := freqs.Rows(ctx, pos)
|
positionalEmbedding := freqs.Rows(ctx, pos)
|
||||||
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
|
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
|
||||||
positionalEmbedding = positionalEmbedding.Concat(ctx, positionalEmbedding, 0)
|
|
||||||
|
|
||||||
return positionalEmbedding
|
return positionalEmbedding
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,6 +325,9 @@ func newVisionModel(c fs.Config) *VisionModel {
|
|||||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||||
ropeTheta: ropeTheta,
|
ropeTheta: ropeTheta,
|
||||||
outHiddenSize: outHiddenSize,
|
outHiddenSize: outHiddenSize,
|
||||||
|
spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)),
|
||||||
|
spatialPatchSize: int(c.Uint("vision.spatial_patch_size", 2)),
|
||||||
|
windowSize: int(c.Uint("vision.window_size", 112)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user