export vision model functions for tests
This commit is contained in:
parent
919b3d6e21
commit
9876c8453a
@ -36,19 +36,15 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
|
||||||
if len(m.VisionModel.Layers) == 0 {
|
|
||||||
return nil, model.ErrNoVisionModel
|
|
||||||
}
|
|
||||||
|
|
||||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate tensor dimensions
|
// Calculate tensor dimensions
|
||||||
@ -58,10 +54,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
|||||||
|
|
||||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create tensor from image: %w", err)
|
return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
|
return pixelValues, grid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||||
|
if len(m.VisionModel.Layers) == 0 {
|
||||||
|
return nil, model.ErrNoVisionModel
|
||||||
|
}
|
||||||
|
|
||||||
|
pixels, grid, err := m.PixelValues(ctx, multimodalData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
||||||
return visionOutputs, nil
|
return visionOutputs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,9 +217,9 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
// Extract patch embeddings
|
// Extract patch embeddings
|
||||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
|
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
|
||||||
|
|
||||||
positionEmbedding := m.positionalEmbedding(ctx, grid)
|
positionEmbedding := m.PositionalEmbedding(ctx, grid)
|
||||||
|
|
||||||
windowIndex, bounds := m.windowIndex(ctx, grid)
|
windowIndex, bounds := m.WindowIndex(ctx, grid)
|
||||||
|
|
||||||
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
||||||
|
|
||||||
@ -250,7 +250,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
return hiddenStates.Rows(ctx, reverseWindowIndex)
|
return hiddenStates.Rows(ctx, reverseWindowIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
// windowIndex divides the grid into windows and returns:
|
// WindowIndex divides the grid into windows and returns:
|
||||||
// 1. A tensor containing flattened indices of all grid points organized by windows
|
// 1. A tensor containing flattened indices of all grid points organized by windows
|
||||||
// 2. A slice of boundaries that mark where each window's data begins and ends
|
// 2. A slice of boundaries that mark where each window's data begins and ends
|
||||||
// in the flattened representation, scaled by spatialMergeSize squared
|
// in the flattened representation, scaled by spatialMergeSize squared
|
||||||
@ -258,7 +258,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
|
|||||||
// The boundaries slice always starts with 0 and contains cumulative ending
|
// The boundaries slice always starts with 0 and contains cumulative ending
|
||||||
// positions for each window, allowing downstream processing to identify
|
// positions for each window, allowing downstream processing to identify
|
||||||
// window boundaries in the tensor data.
|
// window boundaries in the tensor data.
|
||||||
func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
|
func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
|
||||||
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
|
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
|
||||||
|
|
||||||
llmGridH := grid.Height / m.spatialMergeSize
|
llmGridH := grid.Height / m.spatialMergeSize
|
||||||
@ -307,8 +307,8 @@ func (m *VisionModel) windowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int)
|
|||||||
return t, bounds
|
return t, bounds
|
||||||
}
|
}
|
||||||
|
|
||||||
// positionalEmbedding generates rotary position embeddings for attention mechanisms
|
// PositionalEmbedding generates rotary position embeddings for attention mechanisms
|
||||||
func (m *VisionModel) positionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
|
func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||||
dim := m.headDim / 2
|
dim := m.headDim / 2
|
||||||
freq := dim / 2
|
freq := dim / 2
|
||||||
theta := float64(m.ropeTheta)
|
theta := float64(m.ropeTheta)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user