update exported functions for tests
This commit is contained in:
parent
919b3d6e21
commit
9876c8453a
@ -36,19 +36,15 @@ func New(c fs.Config) (model.Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Calculate tensor dimensions
|
||||
@ -58,10 +54,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tensor from image: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
|
||||
return pixelValues, grid, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
pixels, grid, err := m.PixelValues(ctx, multimodalData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
||||
return visionOutputs, nil
|
||||
}
|
||||
|
||||
|
@ -217,9 +217,9 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
// Extract patch embeddings
|
||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
|
||||
|
||||
positionEmbedding := m.positionalEmbedding(ctx, grid)
|
||||
positionEmbedding := m.PositionalEmbedding(ctx, grid)
|
||||
|
||||
windowIndex, bounds := m.windowIndex(ctx, grid)
|
||||
windowIndex, bounds := m.WindowIndex(ctx, grid)
|
||||
|
||||
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
||||
|
||||
@ -250,7 +250,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
return hiddenStates.Rows(ctx, reverseWindowIndex)
|
||||
}
|
||||
|
||||
// windowIndex divides the grid into windows and returns:
|
||||
// WindowIndex divides the grid into windows and returns:
|
||||
// 1. A tensor containing flattened indices of all grid points organized by windows
|
||||
// 2. A slice of boundaries that mark where each window's data begins and ends
|
||||
// in the flattened representation, scaled by spatialMergeSize squared
|
||||
@ -258,7 +258,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
||||
// The boundaries slice always starts with 0 and contains cumulative ending
|
||||
// positions for each window, allowing downstream processing to identify
|
||||
// window boundaries in the tensor data.
|
||||
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
|
||||
func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
|
||||
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
|
||||
|
||||
llmGridH := grid.Height / m.spatialMergeSize
|
||||
@ -307,8 +307,8 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int)
|
||||
return t, bounds
|
||||
}
|
||||
|
||||
// positionalEmbedding generates rotary position embeddings for attention mechanisms
|
||||
func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
// PositionalEmbedding generates rotary position embeddings for attention mechanisms
|
||||
func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
dim := m.headDim / 2
|
||||
freq := dim / 2
|
||||
theta := float64(m.ropeTheta)
|
||||
|
Loading…
x
Reference in New Issue
Block a user