print logprobs

parent 82658c3eec
commit 7d16ec8fe8
@@ -50,7 +50,7 @@ import (
 	_ "github.com/ollama/ollama/llama/llama.cpp/common"
 	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
 	_ "github.com/ollama/ollama/llama/llama.cpp/src"
-	"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 )
 
 func BackendInit() {
@@ -220,6 +220,31 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return embeddings
 }
 
+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine array size
+	vocabSize := c.Model().NumVocab()
+	return unsafe.Slice((*float32)(logits), vocabSize)
+}
+
+func (m *Model) Detokenize(tokens []int) (string, error) {
+	var text string
+	for _, token := range tokens {
+		piece := m.TokenToPiece(token)
+		if piece == "" {
+			return "", fmt.Errorf("failed to convert token %d to piece", token)
+		}
+		text += piece
+	}
+	return text, nil
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
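A note on the new `Context.GetLogits` binding: the returned slice is a view over memory owned by llama.cpp, so it is only valid until the next decode call, which is why the runner copies it immediately (see `getTokenProbabilities` below). The following is a minimal, hypothetical usage sketch, not code from this commit; it assumes a `*llama.Context` and `*llama.Model` obtained through the existing bindings:

```go
package example

import (
	"fmt"

	"github.com/ollama/ollama/llama"
)

// greedyToken is a hypothetical helper (not part of this commit). It copies
// the logits out of the C-owned buffer returned by GetLogits, then picks and
// detokenizes the highest-scoring token.
func greedyToken(lc *llama.Context, m *llama.Model) (string, error) {
	logits := lc.GetLogits()
	if logits == nil {
		return "", fmt.Errorf("no logits available; decode a batch first")
	}

	// GetLogits aliases memory that the next decode overwrites, so copy it
	// before holding on to it.
	scores := make([]float32, len(logits))
	copy(scores, logits)

	best := 0
	for i, s := range scores {
		if s > scores[best] {
			best = i
		}
	}
	return m.Detokenize([]int{best})
}
```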
@@ -8,12 +8,14 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
 	"path/filepath"
 	"regexp"
 	"runtime"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -83,6 +85,8 @@ type Sequence struct {
 
 	doneReason string
 
+	logits []float32
+
 	// Metrics
 	startProcessingTime time.Time
 	startGenerationTime time.Time
@@ -274,6 +278,9 @@ func (s *Server) allNil() bool {
 }
 
 func flushPending(seq *Sequence) bool {
+	if len(seq.pendingResponses) == 0 {
+		return true
+	}
 	joined := strings.Join(seq.pendingResponses, "")
 	seq.pendingResponses = []string{}
 
@@ -287,8 +294,11 @@ func flushPending(seq *Sequence) bool {
 		joined = joined[:len(joined)-1]
 	}
 
-	if len(joined) == 0 {
-		return true
+	// Add logits if requested and available
+	wantLogits := true
+	if wantLogits && seq.logits != nil {
+		// resp.Logits = seq.logits
+		seq.logits = nil
 	}
 
 	select {
@@ -350,6 +360,57 @@ func (s *Server) run(ctx context.Context) {
 	}
 }
 
+// TokenData represents probability information for a token
+type TokenData struct {
+	TokenID int
+	Logit   float32
+	Prob    float32
+	LogProb float32
+}
+
+// getTokenProbabilities returns sorted token probabilities for a specific token index
+func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData {
+	// Get logits for the specific token index
+	logits := s.lc.GetLogits()
+	seq.logits = make([]float32, len(logits))
+	copy(seq.logits, logits)
+
+	vocabSize := s.model.NumVocab()
+	probs := make([]TokenData, vocabSize)
+
+	// Initialize token data with logits
+	for i := 0; i < vocabSize; i++ {
+		probs[i] = TokenData{
+			TokenID: i,
+			Logit:   logits[i],
+		}
+	}
+
+	// Sort tokens by logits in descending order
+	sort.Slice(probs, func(i, j int) bool {
+		return probs[i].Logit > probs[j].Logit
+	})
+
+	// Apply softmax
+	maxLogit := probs[0].Logit
+	var sum float32 = 0.0
+
+	for i := range probs {
+		p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
+		probs[i].Prob = p
+		sum += p
+	}
+
+	// Normalize probabilities and calculate log probs
+	for i := range probs {
+		prob := probs[i].Prob / sum
+		probs[i].Prob = prob
+		probs[i].LogProb = float32(math.Log(float64(prob)))
+	}
+
+	return probs
+}
+
 // TODO (jmorganca): processBatch should be simplified, removing:
 // * sampling
 // * stop token checking
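For reference, the math in `getTokenProbabilities` is the standard max-shifted softmax: with m = max_j l_j, p_i = exp(l_i - m) / sum_j exp(l_j - m), and log p_i = (l_i - m) - log(sum_j exp(l_j - m)). Below is a standalone sketch (not part of the commit) of the same computation on a toy logit vector, showing that the log probability can also be taken directly from the shifted logits without the exp/log round trip:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// Tiny hand-made logit vector standing in for a real vocabulary.
	logits := []float32{2.0, 1.0, 0.1}

	// Subtracting the max logit before exponentiating keeps the exponentials
	// in range without changing the resulting probabilities.
	maxLogit := logits[0]
	for _, l := range logits {
		if l > maxLogit {
			maxLogit = l
		}
	}

	exps := make([]float64, len(logits))
	var sum float64
	for i, l := range logits {
		exps[i] = math.Exp(float64(l - maxLogit))
		sum += exps[i]
	}

	for i, l := range logits {
		prob := exps[i] / sum
		// Same value as math.Log(prob), computed from the shifted logit.
		logProb := float64(l-maxLogit) - math.Log(sum)
		fmt.Printf("token %d: prob=%.4f logprob=%.4f\n", i, prob, logProb)
	}
}
```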
@@ -483,6 +544,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		seq.numPredicted++
 
+		// TODO: only do this when flag specified
+		probs := s.getTokenProbabilities(seq)
+		for i := range 10 {
+			slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
+		}
+
 		// if it's an end of sequence token, break
 		if s.model.TokenIsEog(token) {
 			// TODO (jmorganca): we should send this back
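The TODO above notes that the top-token dump should only run when a flag asks for it. One possible shape for that gating, sketched against the `Server` and `Sequence` types from this file; the `sequenceOptions` fields and the helper name are hypothetical and not part of the commit:

```go
// Hypothetical sketch only: sequenceOptions and maybeLogTopTokens do not
// exist in this commit. The idea is to skip the full-vocabulary probability
// computation unless log probabilities were actually requested.
type sequenceOptions struct {
	logprobs    bool // whether to compute log probabilities at all
	topLogprobs int  // how many of the highest-probability tokens to report
}

func maybeLogTopTokens(s *Server, seq *Sequence, opts sequenceOptions) {
	if !opts.logprobs || opts.topLogprobs <= 0 {
		return
	}

	probs := s.getTokenProbabilities(seq)
	n := opts.topLogprobs
	if n > len(probs) {
		n = len(probs)
	}
	for i := 0; i < n; i++ {
		slog.Debug("top tokens",
			"token", probs[i].TokenID,
			"prob", probs[i].Prob,
			"logprob", probs[i].LogProb,
			"piece", s.model.TokenToPiece(probs[i].TokenID))
	}
}
```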