diff --git a/sample/transforms.go b/sample/transforms.go index 82dbb1fec..0c36bda66 100644 --- a/sample/transforms.go +++ b/sample/transforms.go @@ -126,61 +126,77 @@ func minP(ts []token, p float32) []token { return ts } -// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584 -// sortLogits sorts implementation to sort tokens by logits using counting sort -// counting sort is faster than built-in sort for this use case -func sortLogits(tokens []token) { - if len(tokens) <= 1 { - return +// partialSortLogits uses quickselect to efficiently find and sort the top n tokens +func partialSortLogits(ts []token, n int) []token { + if n >= len(ts) { + n = len(ts) } - // Find max/min in a single pass - minLogit, maxLogit := tokens[0].value, tokens[0].value - for _, t := range tokens[1:] { - if t.value < minLogit { - minLogit = t.value - } else if t.value > maxLogit { - maxLogit = t.value + left, right := 0, len(ts)-1 + target := n - 1 + + // Quickselect algorithm to partition array around pivot + for left < right { + // Choose middle element as pivot and move it to the end + pivot := left + (right-left)/2 + ts[pivot], ts[right] = ts[right], ts[pivot] + + // storeIndex tracks where to put next element greater than pivot + storeIndex := left + pivotValue := ts[right].value + + // Partition array into elements >= pivot and < pivot + // Elements >= pivot go to the left side + for i := left; i < right; i++ { + if ts[i].value >= pivotValue { + ts[storeIndex], ts[i] = ts[i], ts[storeIndex] + storeIndex++ } } - // Calculate scaling to map to uint32 range - logitRange := maxLogit - minLogit - if logitRange < 1e-6 { - return // All values effectively equal + // Move pivot to its final position + ts[right], ts[storeIndex] = ts[storeIndex], ts[right] + + // If pivot is at target position, we're done + // Otherwise recursively partition the half containing target + if storeIndex == target { + break + } else if storeIndex < target { + left = storeIndex + 1 // Target is in right half + } else { + right = storeIndex - 1 // Target is in left half + } } - // Count frequencies directly from tokens - const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity - var counts [256]int // For first byte + // Sort just the top n elements in descending order + slices.SortFunc(ts[:n], func(a, b token) int { + if a.value > b.value { + return -1 + } + if a.value < b.value { + return 1 + } + return 0 + }) - // First pass: count frequencies - for _, t := range tokens { - // Map to [0, maxInt] range - score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt) - counts[score>>16]++ + return ts[:n] } - // Calculate offsets - var offset int - for i := range counts { - count := counts[i] - counts[i] = offset - offset += count +// sortLogits uses partialSortLogits to efficiently sort tokens +// It sorts approximately sqrt(len(tokens)) elements which balances +// between having enough tokens for sampling while avoiding full sort +func sortLogits(ts []token) { + // Use sqrt of token length as a heuristic for partial sort size + // This provides a good balance between performance and having enough tokens + n := int(math.Sqrt(float64(len(ts)))) + 1 + + // Ensure we have at least 100 tokens and at most 1000 + switch { + case n < 100: + n = 100 + case n > 1000: + n = 1000 } - // Second pass: place elements in correct position - output := make([]token, len(tokens)) - // Track current positions - countsCopy := counts - - for i, t := range tokens { - score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt) - - pos := countsCopy[score>>16] - countsCopy[score>>16]++ - output[len(tokens)-1-pos] = tokens[i] - } - - copy(tokens, output) + partialSortLogits(ts, n) }