sample: use partial sort for sorting

This commit is contained in:
ParthSareen 2025-03-12 00:46:12 -04:00
parent 448fc4cd2a
commit 310b235626

View File

@ -126,61 +126,77 @@ func minP(ts []token, p float32) []token {
return ts return ts
} }
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584 // partialSortLogits uses quickselect to efficiently find and sort the top n tokens
// sortLogits sorts implementation to sort tokens by logits using counting sort func partialSortLogits(ts []token, n int) []token {
// counting sort is faster than built-in sort for this use case if n >= len(ts) {
func sortLogits(tokens []token) { n = len(ts)
if len(tokens) <= 1 {
return
} }
// Find max/min in a single pass left, right := 0, len(ts)-1
minLogit, maxLogit := tokens[0].value, tokens[0].value target := n - 1
for _, t := range tokens[1:] {
if t.value < minLogit { // Quickselect algorithm to partition array around pivot
minLogit = t.value for left < right {
} else if t.value > maxLogit { // Choose middle element as pivot and move it to the end
maxLogit = t.value pivot := left + (right-left)/2
ts[pivot], ts[right] = ts[right], ts[pivot]
// storeIndex tracks where to put next element greater than pivot
storeIndex := left
pivotValue := ts[right].value
// Partition array into elements >= pivot and < pivot
// Elements >= pivot go to the left side
for i := left; i < right; i++ {
if ts[i].value >= pivotValue {
ts[storeIndex], ts[i] = ts[i], ts[storeIndex]
storeIndex++
} }
} }
// Calculate scaling to map to uint32 range // Move pivot to its final position
logitRange := maxLogit - minLogit ts[right], ts[storeIndex] = ts[storeIndex], ts[right]
if logitRange < 1e-6 {
return // All values effectively equal // If pivot is at target position, we're done
// Otherwise recursively partition the half containing target
if storeIndex == target {
break
} else if storeIndex < target {
left = storeIndex + 1 // Target is in right half
} else {
right = storeIndex - 1 // Target is in left half
}
} }
// Count frequencies directly from tokens // Sort just the top n elements in descending order
const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity slices.SortFunc(ts[:n], func(a, b token) int {
var counts [256]int // For first byte if a.value > b.value {
return -1
}
if a.value < b.value {
return 1
}
return 0
})
// First pass: count frequencies return ts[:n]
for _, t := range tokens {
// Map to [0, maxInt] range
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
counts[score>>16]++
} }
// Calculate offsets // sortLogits uses partialSortLogits to efficiently sort tokens
var offset int // It sorts approximately sqrt(len(tokens)) elements which balances
for i := range counts { // between having enough tokens for sampling while avoiding full sort
count := counts[i] func sortLogits(ts []token) {
counts[i] = offset // Use sqrt of token length as a heuristic for partial sort size
offset += count // This provides a good balance between performance and having enough tokens
n := int(math.Sqrt(float64(len(ts)))) + 1
// Ensure we have at least 100 tokens and at most 1000
switch {
case n < 100:
n = 100
case n > 1000:
n = 1000
} }
// Second pass: place elements in correct position partialSortLogits(ts, n)
output := make([]token, len(tokens))
// Track current positions
countsCopy := counts
for i, t := range tokens {
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
pos := countsCopy[score>>16]
countsCopy[score>>16]++
output[len(tokens)-1-pos] = tokens[i]
}
copy(tokens, output)
} }