From 448fc4cd2a989d578d72b833e486792f05eec0d1 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Wed, 12 Mar 2025 00:45:41 -0400 Subject: [PATCH] sample: use container/heap for top_k --- sample/transforms.go | 81 ++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 45 deletions(-) diff --git a/sample/transforms.go b/sample/transforms.go index ab62455f3..82dbb1fec 100644 --- a/sample/transforms.go +++ b/sample/transforms.go @@ -1,10 +1,30 @@ package sample import ( + "container/heap" "math" "slices" ) +// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements +type tokenHeap []token + +func (h tokenHeap) Len() int { return len(h) } +func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value } // Use < for min-heap to track largest elements +func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *tokenHeap) Push(x any) { + *h = append(*h, x.(token)) +} + +func (h *tokenHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + // temperature applies scaling and softmax to the logits func temperature(ts []token, temp float32) []token { // Find max logit for numerical stability @@ -31,62 +51,33 @@ func temperature(ts []token, temp float32) []token { return ts } -// siftDown maintains a min-heap property by recursively moving larger elements down the heap. -// -// The heap is represented as an array where for any node at index i: -// - Left child is at index 2i + 1 -// - Right child is at index 2i + 2 -// - Parent is at index (i-1)/2 -// -// The function compares a node with its children and: -// 1. Finds the smallest value between the node and its children -// 2. If the node is not the smallest, swaps it with its smallest child -// 3. Continues this process down the affected path until the min-heap property is restored -func siftDown(data []token, start, end int) { - root := start - for { - child := 2*root + 1 - if child >= end { - break - } - // Find smaller child (we want min heap) - if child+1 < end && data[child+1].value < data[child].value { - child++ - } - // Exit if root is already smaller than children - if data[root].value <= data[child].value { - break - } - // Swap with smaller child and continue - data[root], data[child] = data[child], data[root] - root = child - } -} - // topK limits the number of tokens considered to the k highest logits func topK(ts []token, k int) []token { if k >= len(ts) { + sortLogits(ts) return ts } - // Heapify + siftDown - O(nlog(k)) - // Build min-heap of first k elements - heap := ts[:k] - for i := k/2 - 1; i >= 0; i-- { - siftDown(heap, i, k) - } - // Process remaining elements - if larger than heap root, replace root + // Initialize min-heap with first k elements + h := make(tokenHeap, k) + copy(h, ts[:k]) + heap.Init(&h) + + // Process remaining elements for i := k; i < len(ts); i++ { - if ts[i].value > heap[0].value { - heap[0] = ts[i] - siftDown(heap, 0, k) + if ts[i].value > h[0].value { + heap.Pop(&h) + heap.Push(&h, ts[i]) } } - slices.Reverse(heap) + // Convert heap to sorted slice in descending order + result := make([]token, k) + for i := k - 1; i >= 0; i-- { + result[i] = heap.Pop(&h).(token) + } - ts = heap - return ts + return result } // topP limits tokens to those with cumulative probability p