diff --git a/llama/llama.go b/llama/llama.go index c11d53411..acbf0a672 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -737,3 +737,14 @@ func SchemaToGrammar(schema []byte) []byte { } return buf[:n] } + +// GetLogits returns the logits from the last decode operation. +// The returned slice has length equal to the vocabulary size. +func (c *Context) GetLogits() []float32 { + logits := unsafe.Pointer(C.llama_get_logits(c.c)) + if logits == nil { + return nil + } + + // Get the number of vocabulary tokens to determine array size + vocabSize := c.Model().NumVocab() diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 86c010096..e7ed3ba6d 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -1003,3 +1003,76 @@ func Execute(args []string) error { cancel() return nil } + +// Helper function to get top K logits and convert to log probabilities +func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs { + if k <= 0 { + return nil + } + + // Convert logits to probabilities using softmax + probs := softmax(logits) + + // Create slice of index/probability pairs + pairs := make([]struct { + token int + prob float32 + }, len(probs)) + + for i, p := range probs { + pairs[i] = struct { + token int + prob float32 + }{i, p} + } + + // Sort by probability (descending) + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].prob > pairs[j].prob + }) + + // Take top K + k = min(k, len(pairs)) + result := make([]api.LogProbs, k) + + for i := 0; i < k; i++ { + result[i] = api.LogProbs{ + TopLogprobs: []api.TokenLogprob{ + { + Token: model.TokenToPiece(pairs[i].token), + Logprob: float32(math.Log(float64(pairs[i].prob))), + }, + }, + } + } + + return result +} + +// Helper function to compute softmax +func softmax(logits []float32) []float32 { + probs := make([]float32, len(logits)) + + // Find max for numerical stability + max := float32(math.Inf(-1)) + for _, l := range logits { + if l > max { + max = l + } + } + + // Compute exp(x - max) and sum + sum := float32(0) + for i, l := range logits { + ex := float32(math.Exp(float64(l - max))) + probs[i] = ex + sum += ex + } + + // Normalize + for i := range probs { + probs[i] /= sum + } + + return probs +}