simplify patch creation
This commit is contained in:
parent
39ee6d2bd0
commit
7c555d394c
@ -124,39 +124,33 @@ func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid
|
|||||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||||
|
|
||||||
// Create output tensor
|
|
||||||
result := make([]float32, numPatches*patchDim)
|
result := make([]float32, numPatches*patchDim)
|
||||||
|
|
||||||
// Instead of the complex 9D reshape+transpose, directly extract patches
|
|
||||||
// in the format expected by the forward pass
|
|
||||||
patchIndex := 0
|
patchIndex := 0
|
||||||
|
|
||||||
|
// Single temporal frame handling (copies to all frames)
|
||||||
for range grid.Temporal {
|
for range grid.Temporal {
|
||||||
// For each patch in the grid
|
|
||||||
for h := 0; h < grid.Height; h += mergeSize {
|
for h := 0; h < grid.Height; h += mergeSize {
|
||||||
for w := 0; w < grid.Width; w += mergeSize {
|
for w := 0; w < grid.Width; w += mergeSize {
|
||||||
// Handle the 2x2 merged patches
|
// Handle the 2x2 merged patches
|
||||||
for mh := range mergeSize {
|
for mh := range mergeSize {
|
||||||
for mw := range mergeSize {
|
for mw := range mergeSize {
|
||||||
// For each pixel in the patch
|
baseOffset := patchIndex * patchDim
|
||||||
for py := range patchSize {
|
|
||||||
for px := range patchSize {
|
|
||||||
// Calculate source coordinates
|
|
||||||
y := (h+mh)*patchSize + py
|
|
||||||
x := (w+mw)*patchSize + px
|
|
||||||
|
|
||||||
// For each channel
|
// Extract patch data for first temporal frame
|
||||||
for c := range channels {
|
for c := range channels {
|
||||||
// Channel-first format (CHW)
|
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
|
||||||
|
|
||||||
|
for py := range patchSize {
|
||||||
|
for px := range patchSize {
|
||||||
|
// Calculate source pixel coordinates
|
||||||
|
y := (h+mh)*patchSize + py
|
||||||
|
x := (w+mw)*patchSize + px
|
||||||
|
|
||||||
|
// Source index in input tensor (CHW format)
|
||||||
srcIdx := c*height*width + y*width + x
|
srcIdx := c*height*width + y*width + x
|
||||||
|
|
||||||
// Calculate destination index based on the expected layout
|
// Destination index in first temporal frame
|
||||||
// This is the key part that matches what the model expects
|
dstIdx := channelOffset + (py * patchSize) + px
|
||||||
dstIdx := patchIndex*patchDim +
|
|
||||||
(c * temporalPatchSize * patchSize * patchSize) +
|
|
||||||
(0 * patchSize * patchSize) + // temporal dim
|
|
||||||
(py * patchSize) +
|
|
||||||
px
|
|
||||||
|
|
||||||
if srcIdx < len(pixels) && dstIdx < len(result) {
|
if srcIdx < len(pixels) && dstIdx < len(result) {
|
||||||
result[dstIdx] = pixels[srcIdx]
|
result[dstIdx] = pixels[srcIdx]
|
||||||
@ -165,27 +159,18 @@ func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle temporal dimension padding (if needed)
|
// Copy first temporal frame to all other frames
|
||||||
for tp := 1; tp < temporalPatchSize; tp++ {
|
if temporalPatchSize > 1 {
|
||||||
for py := range patchSize {
|
for c := range channels {
|
||||||
for px := range patchSize {
|
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
|
||||||
for c := range channels {
|
firstFrameOffset := channelOffset
|
||||||
srcIdx := patchIndex*patchDim +
|
frameSize := patchSize * patchSize
|
||||||
(c * temporalPatchSize * patchSize * patchSize) +
|
|
||||||
(0 * patchSize * patchSize) + // first temporal frame
|
|
||||||
(py * patchSize) +
|
|
||||||
px
|
|
||||||
|
|
||||||
dstIdx := patchIndex*patchDim +
|
// Copy first frame to all other frames
|
||||||
(c * temporalPatchSize * patchSize * patchSize) +
|
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||||
(tp * patchSize * patchSize) + // current temporal frame
|
currentFrameOffset := channelOffset + (tp * frameSize)
|
||||||
(py * patchSize) +
|
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
|
||||||
px
|
result[firstFrameOffset:firstFrameOffset+frameSize])
|
||||||
|
|
||||||
if srcIdx < len(result) && dstIdx < len(result) {
|
|
||||||
result[dstIdx] = result[srcIdx] // Copy from first frame
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user