simplify patch creation

This commit is contained in:
Bruce MacDonald 2025-05-01 14:36:07 -07:00
parent 39ee6d2bd0
commit 7c555d394c

View File

@@ -124,39 +124,33 @@ func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid
numPatches := grid.Temporal * grid.Height * grid.Width numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize patchDim := channels * temporalPatchSize * patchSize * patchSize
// Create output tensor
result := make([]float32, numPatches*patchDim) result := make([]float32, numPatches*patchDim)
// Instead of the complex 9D reshape+transpose, directly extract patches
// in the format expected by the forward pass
patchIndex := 0 patchIndex := 0
// Single temporal frame handling (copies to all frames)
for range grid.Temporal { for range grid.Temporal {
// For each patch in the grid
for h := 0; h < grid.Height; h += mergeSize { for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < grid.Width; w += mergeSize { for w := 0; w < grid.Width; w += mergeSize {
// Handle the 2x2 merged patches // Handle the 2x2 merged patches
for mh := range mergeSize { for mh := range mergeSize {
for mw := range mergeSize { for mw := range mergeSize {
// For each pixel in the patch baseOffset := patchIndex * patchDim
for py := range patchSize {
for px := range patchSize {
// Calculate source coordinates
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
// For each channel // Extract patch data for first temporal frame
for c := range channels { for c := range channels {
// Channel-first format (CHW) channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
for py := range patchSize {
for px := range patchSize {
// Calculate source pixel coordinates
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
// Source index in input tensor (CHW format)
srcIdx := c*height*width + y*width + x srcIdx := c*height*width + y*width + x
// Calculate destination index based on the expected layout // Destination index in first temporal frame
// This is the key part that matches what the model expects dstIdx := channelOffset + (py * patchSize) + px
dstIdx := patchIndex*patchDim +
(c * temporalPatchSize * patchSize * patchSize) +
(0 * patchSize * patchSize) + // temporal dim
(py * patchSize) +
px
if srcIdx < len(pixels) && dstIdx < len(result) { if srcIdx < len(pixels) && dstIdx < len(result) {
result[dstIdx] = pixels[srcIdx] result[dstIdx] = pixels[srcIdx]
@@ -165,27 +159,18 @@ func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid
} }
} }
// Handle temporal dimension padding (if needed) // Copy first temporal frame to all other frames
for tp := 1; tp < temporalPatchSize; tp++ { if temporalPatchSize > 1 {
for py := range patchSize { for c := range channels {
for px := range patchSize { channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
for c := range channels { firstFrameOffset := channelOffset
srcIdx := patchIndex*patchDim + frameSize := patchSize * patchSize
(c * temporalPatchSize * patchSize * patchSize) +
(0 * patchSize * patchSize) + // first temporal frame
(py * patchSize) +
px
dstIdx := patchIndex*patchDim + // Copy first frame to all other frames
(c * temporalPatchSize * patchSize * patchSize) + for tp := 1; tp < temporalPatchSize; tp++ {
(tp * patchSize * patchSize) + // current temporal frame currentFrameOffset := channelOffset + (tp * frameSize)
(py * patchSize) + copy(result[currentFrameOffset:currentFrameOffset+frameSize],
px result[firstFrameOffset:firstFrameOffset+frameSize])
if srcIdx < len(result) && dstIdx < len(result) {
result[dstIdx] = result[srcIdx] // Copy from first frame
}
}
} }
} }
} }