From 9876c8453aecb458dec9f2d9253e0bf92576e7bc Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Wed, 7 May 2025 15:35:29 -0700 Subject: [PATCH] update exported functions for tests --- model/models/qwen25vl/model.go | 27 ++++++++++++++++++--------- model/models/qwen25vl/model_vision.go | 12 ++++++------ 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 552c38cc0..546e68b13 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -36,19 +36,15 @@ func New(c fs.Config) (model.Model, error) { return m, nil } -func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { - if len(m.VisionModel.Layers) == 0 { - return nil, model.ErrNoVisionModel - } - +func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) { image, _, err := image.Decode(bytes.NewReader(multimodalData)) if err != nil { - return nil, err + return nil, nil, err } f32s, grid, err := m.ImageProcessor.ProcessImage(image) if err != nil { - return nil, err + return nil, nil, err } // Calculate tensor dimensions @@ -58,10 +54,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) if err != nil { - return nil, fmt.Errorf("failed to create tensor from image: %w", err) + return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err) } - visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid) + return pixelValues, grid, nil +} + +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { + if len(m.VisionModel.Layers) == 0 { + return nil, model.ErrNoVisionModel + } + + pixels, grid, err := m.PixelValues(ctx, multimodalData) + if err != nil { + return nil, err + } + + visionOutputs := m.VisionModel.Forward(ctx, pixels, grid) return visionOutputs, nil } diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 2c839b3bd..1574c0bc7 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -217,9 +217,9 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) // Extract patch embeddings hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions) - positionEmbedding := m.positionalEmbedding(ctx, grid) + positionEmbedding := m.PositionalEmbedding(ctx, grid) - windowIndex, bounds := m.windowIndex(ctx, grid) + windowIndex, bounds := m.WindowIndex(ctx, grid) spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize @@ -250,7 +250,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) return hiddenStates.Rows(ctx, reverseWindowIndex) } -// windowIndex divides the grid into windows and returns: +// WindowIndex divides the grid into windows and returns: // 1. A tensor containing flattened indices of all grid points organized by windows // 2. A slice of boundaries that mark where each window's data begins and ends // in the flattened representation, scaled by spatialMergeSize squared @@ -258,7 +258,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) // The boundaries slice always starts with 0 and contains cumulative ending // positions for each window, allowing downstream processing to identify // window boundaries in the tensor data. -func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) { +func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) { vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize llmGridH := grid.Height / m.spatialMergeSize @@ -307,8 +307,8 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) return t, bounds } -// positionalEmbedding generates rotary position embeddings for attention mechanisms -func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor { +// PositionalEmbedding generates rotary position embeddings for attention mechanisms +func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor { dim := m.headDim / 2 freq := dim / 2 theta := float64(m.ropeTheta)