diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index e8375c0b5..762ec614e 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -1,7 +1,6 @@
 package qwen25vl
 
 import (
-	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/fs"
@@ -122,45 +121,32 @@ func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, patchSi
 	shape := pixelValues.Shape()
 	numChannels := 3
 	temporalPatchSize := 2
+	numPatches := shape[1] // TODO: check this
+	// patch size from args
 	embedDim := 1280
-	numPatches := shape[1] / temporalPatchSize
 
-	// Split the input tensor into two temporal slices and process each separately
-	// First temporal slice (frame 0)
-	slice0 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 0, 1).Contiguous(ctx)
-	reshaped0 := slice0.Reshape(ctx,
-		patchSize,   // height
-		patchSize,   // width
-		numChannels, // channels
-		numPatches)  // batch
+	// Reshape the input tensor to match the expected dimensions
+	pixelValues = pixelValues.Reshape(ctx, patchSize*patchSize, temporalPatchSize, numChannels, numPatches)
 
-	// Second temporal slice (frame 1)
-	slice1 := pixelValues.View(ctx, 0, patchSize*patchSize*numChannels, 0, numPatches, 1, 1).Contiguous(ctx)
-	reshaped1 := slice1.Reshape(ctx,
-		patchSize,   // height
-		patchSize,   // width
-		numChannels, // channels
-		numPatches)  // batch
+	// Permute the tensor to bring the temporal dimension to the front
+	pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 
-	// Apply the appropriate convolution to each temporal slice
-	// PatchConv0 corresponds to weights for temporal frame 0
-	// PatchConv1 corresponds to weights for temporal frame 1
-	s0, s1 := patchSize, patchSize // Use full stride as in original
+	in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
+	in0 = in0.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
+	in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
+	in1 = in1.Reshape(ctx, patchSize, patchSize, numChannels, numPatches)
+
+	s0, s1 := patchSize, patchSize // Use full stride
 	p0, p1 := 0, 0                 // padding
 	d0, d1 := 1, 1                 // dilation
-
-	output0 := pe.PatchConv0.Forward(ctx, reshaped0, s0, s1, p0, p1, d0, d1)
-	output1 := pe.PatchConv1.Forward(ctx, reshaped1, s0, s1, p0, p1, d0, d1)
+	out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
+	out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
 
 	// Add the outputs from the two temporal convolutions
-	combined := output0.Add(ctx, output1)
+	out := out0.Add(ctx, out1)
 
-	// Reshape to required output dimensions
-	result := combined.Reshape(ctx, embedDim, numPatches)
-
-	fmt.Println(ml.Dump(ctx, result))
-
-	return result
+	// Reshape the output tensor to match the expected dimensions
+	return out.Reshape(ctx, embedDim, numPatches)
 }
 
 // VisionPatchMerger implements patch merging for the Qwen vision model
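
Note on the strided View calls in the hunk above: after the Reshape and Permute(ctx, 1, 0, 2, 3).Contiguous(ctx), the temporal dimension is innermost (ggml-style layout, where dim 0 varies fastest), so the two temporal frames are interleaved element-wise in memory, and each View of width 1 along dim 0 (at offsets 0 and Stride(0)) selects one frame while keeping the remaining dims via their strides. Below is a minimal sketch of the same split in plain Go, using slices instead of the ml.Tensor API; splitTemporalFrames is a hypothetical name for illustration, not part of this change.

package main

import "fmt"

// splitTemporalFrames is a hypothetical illustration of the two View calls
// above: on a buffer where temporalPatchSize frames are interleaved
// element-wise, it gathers every temporalPatchSize-th element starting at
// offset t, which mirrors a width-1 view along the innermost dim at
// offsets 0 and Stride(0).
func splitTemporalFrames(buf []float32, temporalPatchSize int) [][]float32 {
	n := len(buf) / temporalPatchSize
	frames := make([][]float32, temporalPatchSize)
	for t := range frames {
		frames[t] = make([]float32, n)
		for i := 0; i < n; i++ {
			frames[t][i] = buf[i*temporalPatchSize+t]
		}
	}
	return frames
}

func main() {
	// Two temporal frames interleaved element-wise: f0, f1, f0, f1, ...
	buf := []float32{0, 10, 1, 11, 2, 12}
	frames := splitTemporalFrames(buf, 2)
	fmt.Println(frames[0]) // [0 1 2]    -> analogous to in0 (fed to PatchConv0)
	fmt.Println(frames[1]) // [10 11 12] -> analogous to in1 (fed to PatchConv1)
}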