update exported functions for tests

This commit is contained in:
Bruce MacDonald 2025-05-07 15:35:29 -07:00
parent 919b3d6e21
commit 9876c8453a
2 changed files with 24 additions and 15 deletions

View File

@ -36,19 +36,15 @@ func New(c fs.Config) (model.Model, error) {
return m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
return nil, nil, err
}
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
return nil, nil, err
}
// Calculate tensor dimensions
@ -58,10 +54,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
if err != nil {
return nil, fmt.Errorf("failed to create tensor from image: %w", err)
return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
return pixelValues, grid, nil
}
// EncodeMultimodal converts raw multimodal (image) bytes into vision model
// outputs.
//
// It returns model.ErrNoVisionModel when the vision model has no layers, and
// passes through unchanged any error reported by PixelValues. On success the
// result is whatever m.VisionModel.Forward produces for the pixel tensor and
// grid returned by PixelValues.
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
	// Without vision layers there is nothing to run the image through.
	if len(m.VisionModel.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}

	pixelValues, grid, err := m.PixelValues(ctx, multimodalData)
	if err != nil {
		return nil, err
	}

	return m.VisionModel.Forward(ctx, pixelValues, grid), nil
}

View File

@ -217,9 +217,9 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
// Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
positionEmbedding := m.positionalEmbedding(ctx, grid)
positionEmbedding := m.PositionalEmbedding(ctx, grid)
windowIndex, bounds := m.windowIndex(ctx, grid)
windowIndex, bounds := m.WindowIndex(ctx, grid)
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
@ -250,7 +250,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
return hiddenStates.Rows(ctx, reverseWindowIndex)
}
// windowIndex divides the grid into windows and returns:
// WindowIndex divides the grid into windows and returns:
// 1. A tensor containing flattened indices of all grid points organized by windows
// 2. A slice of boundaries that mark where each window's data begins and ends
// in the flattened representation, scaled by spatialMergeSize squared
@ -258,7 +258,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
// The boundaries slice always starts with 0 and contains cumulative ending
// positions for each window, allowing downstream processing to identify
// window boundaries in the tensor data.
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
llmGridH := grid.Height / m.spatialMergeSize
@ -307,8 +307,8 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int)
return t, bounds
}
// positionalEmbedding generates rotary position embeddings for attention mechanisms
func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
// PositionalEmbedding generates rotary position embeddings for attention mechanisms
func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
dim := m.headDim / 2
freq := dim / 2
theta := float64(m.ropeTheta)