Compare commits

...

4 Commits

Author       SHA1        Message                      Date
ParthSareen  afa2e855d4  log probs working            2025-01-10 11:15:31 -08:00
ParthSareen  f9928b677f  Working e2e logits           2025-01-03 16:06:28 -08:00
ParthSareen  c92d418a7c  WIP but got logits n stuff   2025-01-02 13:32:33 -08:00
ParthSareen  d7e7e6a01e  wip                          2025-01-02 13:32:33 -08:00
6 changed files with 255 additions and 53 deletions

View File

@@ -129,7 +129,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }

-const maxBufferSize = 512 * format.KiloByte
+const maxBufferSize = 1024 * format.KiloByte

 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer
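A vocabulary-sized logits array serialized into a single streamed JSON line can easily exceed the old 512 KB scan limit, which is presumably why the client buffer doubles here. A minimal standalone sketch of the pattern (the stream contents and variable names are illustrative, not the actual client internals):

    package main

    import (
        "bufio"
        "fmt"
        "strings"
    )

    const maxBufferSize = 1024 * 1024 // mirrors the raised 1024 KB limit

    func main() {
        // One streamed response line carrying a (truncated) logits array.
        body := strings.NewReader(`{"content":"hi","logits":[0.1,0.2,0.3]}` + "\n")

        scanner := bufio.NewScanner(body)
        // bufio.Scanner rejects lines over 64 KB by default with
        // bufio.ErrTooLong, so grow the line buffer up front.
        scanner.Buffer(make([]byte, 0, maxBufferSize), maxBufferSize)
        for scanner.Scan() {
            fmt.Println(scanner.Text())
        }
    }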

View File

@@ -80,6 +80,8 @@ type GenerateRequest struct {
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
+
+	ReturnLogits bool `json:"return_logits,omitempty"`
 }

 // ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +107,8 @@ type ChatRequest struct {
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
+
+	ReturnLogits bool `json:"return_logits,omitempty"`
 }

 type Tools []Tool
@@ -185,10 +189,12 @@ func (t *ToolFunction) String() string {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
 	Model      string    `json:"model"`
 	CreatedAt  time.Time `json:"created_at"`
 	Message    Message   `json:"message"`
 	DoneReason string    `json:"done_reason,omitempty"`
+	Logits      []float32      `json:"logits"`
+	TopLogprobs []TokenLogprob `json:"top_logprobs"`

 	Done bool `json:"done"`
@@ -204,6 +210,11 @@ type Metrics struct {
 	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
 }

+type TokenLogprob struct {
+	Text    string  `json:"text"`
+	Logprob float32 `json:"logprob"`
+}
+
 // Options specified in [GenerateRequest]. If you add a new option here, also
 // add it to the API docs.
 type Options struct {
@@ -450,6 +461,8 @@ type GenerateResponse struct {
 	Context []int `json:"context,omitempty"`

 	Metrics
+
+	Logits []float32 `json:"logits"`
 }

 // ModelDetails provides details about a model.
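With these fields in place, a caller opts in per request and reads per-token log probabilities off each streamed chunk. A sketch of the call site, assuming this branch of github.com/ollama/ollama/api (the model name is illustrative):

    package main

    import (
        "context"
        "fmt"
        "log"

        "github.com/ollama/ollama/api"
    )

    func main() {
        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }

        req := &api.ChatRequest{
            Model:        "llama3.2", // illustrative model name
            Messages:     []api.Message{{Role: "user", Content: "Hi"}},
            ReturnLogits: true, // new field from this change
        }

        err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
            for _, lp := range resp.TopLogprobs {
                fmt.Printf("%q: %.4f\n", lp.Text, lp.Logprob)
            }
            return nil
        })
        if err != nil {
            log.Fatal(err)
        }
    }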

View File

@@ -260,6 +260,31 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }

+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine array size
+	vocabSize := c.Model().NumVocab()
+	return unsafe.Slice((*float32)(logits), vocabSize)
+}
+
+func (m *Model) Detokenize(tokens []int) (string, error) {
+	var text string
+	for _, token := range tokens {
+		piece := m.TokenToPiece(token)
+		if piece == "" {
+			return "", fmt.Errorf("failed to convert token %d to piece", token)
+		}
+		text += piece
+	}
+	return text, nil
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
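GetLogits returns one raw score per vocabulary entry from the most recent decode, so greedy sampling reduces to an argmax over the returned slice. A pure-Go sketch with a toy three-token vocabulary (no cgo, so it runs standalone):

    package main

    import "fmt"

    // argmax returns the index of the highest logit, i.e. the greedy token id.
    func argmax(logits []float32) int {
        best := 0
        for i, l := range logits {
            if l > logits[best] {
                best = i
            }
        }
        return best
    }

    func main() {
        logits := []float32{-1.2, 3.4, 0.7} // one raw score per vocab token
        fmt.Println("greedy token id:", argmax(logits)) // 1
    }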

View File

@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
@@ -59,7 +60,7 @@ type Sequence struct {
 	crossAttention bool

 	// channel to send responses over
-	responses chan string
+	responses chan CompletionResponse

 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -88,6 +89,15 @@ type Sequence struct {
 	startGenerationTime time.Time
 	numDecoded          int
 	numPromptInputs     int
+
+	// whether logits should be captured and returned with responses
+	returnLogits bool
+
+	// logits from the most recent decode, captured via Context.GetLogits
+	logits []float32
+
+	// channel for sending logits to the handler
+	logitsOut chan []float32
 }

 type NewSequenceParams struct {
@@ -96,6 +106,7 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	returnLogits   bool
 }

 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -149,13 +160,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
 		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		responses:           make(chan CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
 		embeddingOnly:       params.embedding,
 		stop:                params.stop,
 		numKeep:             params.numKeep,
+		returnLogits:        params.returnLogits,
+		logitsOut:           make(chan []float32, 100),
 	}, nil
 }
@@ -274,25 +287,34 @@ func (s *Server) allNil() bool {
 }

 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
+	if len(seq.pendingResponses) == 0 {
+		return true
+	}
+	content := strings.Join(seq.pendingResponses, "")

 	// Check if there are any partial UTF-8 characters remaining.
 	// We already check and queue as we are generating but some may
 	// still make it here:
 	// - Sequence is ending, e.g. generation limit has been hit
 	// - Invalid characters in the middle of a string
 	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
-	}
-
-	if len(joined) == 0 {
-		return true
-	}
+	for !utf8.ValidString(content) {
+		content = content[:len(content)-1]
+	}
+	seq.pendingResponses = nil
+
+	resp := CompletionResponse{
+		Content: content,
+	}
+
+	// Add logits if requested and available
+	if seq.returnLogits && seq.logits != nil {
+		resp.Logits = seq.logits
+		seq.logits = nil
+	}

 	select {
-	case seq.responses <- joined:
+	case seq.responses <- resp:
 		return true
 	case <-seq.quit:
 		return false
@@ -476,7 +498,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}

-		// sample a token
+		// Copy the logits out before sampling; llama.cpp reuses the
+		// underlying buffer on the next decode call.
+		if seq.returnLogits {
+			logits := s.lc.GetLogits()
+			seq.logits = make([]float32, len(logits))
+			copy(seq.logits, logits)
+		}
+
+		// sample a token
 		token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
 		seq.samplingCtx.Accept(token, true)
 		piece := s.model.TokenToPiece(token)
@@ -572,10 +601,11 @@ type ImageData struct {
 }

 type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
+	Prompt       string      `json:"prompt"`
+	Images       []ImageData `json:"image_data"`
+	Grammar      string      `json:"grammar"`
+	CachePrompt  bool        `json:"cache_prompt"`
+	ReturnLogits bool        `json:"return_logits,omitempty"` // defaults to false

 	Options
 }
@@ -588,8 +618,10 @@ type Timings struct {
 }

 type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
+	Content string    `json:"content"`
+	Logits  []float32 `json:"logits,omitempty"`
+	Tokens  []string  `json:"tokens,omitempty"`
+	Stop    bool      `json:"stop"`

 	Model  string `json:"model,omitempty"`
 	Prompt string `json:"prompt,omitempty"`
@@ -637,12 +669,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar

+	slog.Info("completion request", "return_logits", req.ReturnLogits)
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict:     req.NumPredict,
 		stop:           req.Stop,
 		numKeep:        req.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		returnLogits:   req.ReturnLogits,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -692,9 +726,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Content: content,
-				}); err != nil {
+				// slog.Info("content", "content", content.Content)
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
@@ -1003,3 +1036,76 @@ func Execute(args []string) error {
 	cancel()
 	return nil
 }
+
+// Helper function to get top K logits and convert to log probabilities
+// (commented out; superseded by getTopKLogProbs in the server routes)
+//
+// func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
+// 	if k <= 0 {
+// 		return nil
+// 	}
+//
+// 	// Convert logits to probabilities using softmax
+// 	probs := softmax(logits)
+//
+// 	// Create slice of index/probability pairs
+// 	pairs := make([]struct {
+// 		token int
+// 		prob  float32
+// 	}, len(probs))
+// 	for i, p := range probs {
+// 		pairs[i] = struct {
+// 			token int
+// 			prob  float32
+// 		}{i, p}
+// 	}
+//
+// 	// Sort by probability (descending)
+// 	sort.Slice(pairs, func(i, j int) bool {
+// 		return pairs[i].prob > pairs[j].prob
+// 	})
+//
+// 	// Take top K
+// 	k = min(k, len(pairs))
+// 	result := make([]api.LogProbs, k)
+// 	for i := 0; i < k; i++ {
+// 		result[i] = api.LogProbs{
+// 			TopLogprobs: []api.TokenLogprob{
+// 				{
+// 					Token:   model.TokenToPiece(pairs[i].token),
+// 					Logprob: float32(math.Log(float64(pairs[i].prob))),
+// 				},
+// 			},
+// 		}
+// 	}
+// 	return result
+// }
+
+// Helper function to compute softmax
+func softmax(logits []float32) []float32 {
+	probs := make([]float32, len(logits))
+
+	// Find max for numerical stability
+	max := float32(math.Inf(-1))
+	for _, l := range logits {
+		if l > max {
+			max = l
+		}
+	}
+
+	// Compute exp(x - max) and sum
+	sum := float32(0)
+	for i, l := range logits {
+		ex := float32(math.Exp(float64(l - max)))
+		probs[i] = ex
+		sum += ex
+	}
+
+	// Normalize
+	for i := range probs {
+		probs[i] /= sum
+	}
+
+	return probs
+}
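Two properties make this implementation sound: subtracting the max logit leaves the result unchanged while preventing overflow in exp, and log(softmax(x[i])) equals x[i] - logSumExp(x), the identity the routes-level getTopKLogProbs below relies on. A quick standalone check in float64:

    package main

    import (
        "fmt"
        "math"
    )

    func main() {
        logits := []float64{2.0, 1.0, 0.1}

        // softmax with max subtraction, as above
        max := math.Inf(-1)
        for _, l := range logits {
            if l > max {
                max = l
            }
        }
        sum := 0.0
        probs := make([]float64, len(logits))
        for i, l := range logits {
            probs[i] = math.Exp(l - max)
            sum += probs[i]
        }
        total := 0.0
        for i := range probs {
            probs[i] /= sum
            total += probs[i]
        }
        fmt.Printf("probabilities sum to %.6f\n", total) // 1.000000

        // log(softmax(x[i])) == x[i] - logSumExp(x)
        logSumExp := math.Log(sum) + max
        fmt.Printf("%.6f == %.6f\n", math.Log(probs[0]), logits[0]-logSumExp)
    }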

View File

@@ -633,7 +633,8 @@ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
 ws ::= ([ \t\n] ws)?
 `

-const maxBufferSize = 512 * format.KiloByte
+// TODO: change back to 512 * format.KiloByte
+const maxBufferSize = 2048 * format.KiloByte

 type ImageData struct {
 	Data []byte `json:"data"`
@@ -642,11 +643,12 @@ type ImageData struct {
 }

 type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
+	Content      string    `json:"content"`
+	Model        string    `json:"model"`
+	Prompt       string    `json:"prompt"`
+	Stop         bool      `json:"stop"`
+	StoppedLimit bool      `json:"stopped_limit"`
+	Logits       []float32 `json:"logits,omitempty"`

 	Timings struct {
 		PredictedN int `json:"predicted_n"`
@@ -657,10 +659,11 @@ type completion struct {
 }

 type CompletionRequest struct {
-	Prompt  string
-	Format  json.RawMessage
-	Images  []ImageData
-	Options *api.Options
+	Prompt       string
+	Format       json.RawMessage
+	Images       []ImageData
+	Options      *api.Options
+	ReturnLogits bool
 }

 type CompletionResponse struct {
@@ -671,6 +674,7 @@ type CompletionResponse struct {
 	PromptEvalDuration time.Duration
 	EvalCount          int
 	EvalDuration       time.Duration
+	Logits             []float32
 }

 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -696,6 +700,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		"seed":          req.Options.Seed,
 		"stop":          req.Options.Stop,
 		"image_data":    req.Images,
+		"return_logits": req.ReturnLogits,
 		"cache_prompt":  true,
 	}
@@ -821,6 +826,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if c.Content != "" {
 				fn(CompletionResponse{
 					Content: c.Content,
+					Logits:  c.Logits,
 				})
 			}
@@ -837,6 +843,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 				EvalCount:          c.Timings.PredictedN,
 				EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				Logits:             c.Logits,
 			})
 			return nil
 		}
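On the wire, the only change to the server-to-runner request is the extra return_logits key. A sketch of the abridged payload shape (field set and values illustrative):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        request := map[string]any{
            "prompt":        "Why is the sky blue?", // illustrative
            "return_logits": true,                   // the new key
            "cache_prompt":  true,
        }
        b, err := json.MarshalIndent(request, "", "  ")
        if err != nil {
            panic(err)
        }
        fmt.Println(string(b))
    }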

View File

@@ -19,6 +19,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"sort"
 	"strings"
 	"syscall"
 	"time"
@@ -295,10 +296,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:       prompt,
+			Images:       images,
+			Format:       req.Format,
+			Options:      opts,
+			ReturnLogits: false,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model: req.Model,
@@ -312,6 +314,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalCount:    cr.EvalCount,
 					EvalDuration: cr.EvalDuration,
 				},
+				Logits: cr.Logits,
 			}

 			if _, err := sb.WriteString(cr.Content); err != nil {
@@ -1547,26 +1550,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		var sb strings.Builder
 		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
-		}, func(r llm.CompletionResponse) {
+			Prompt:       prompt,
+			Images:       images,
+			Format:       req.Format,
+			Options:      opts,
+			ReturnLogits: true,
+		}, func(cr llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
-				Message:    api.Message{Role: "assistant", Content: r.Content},
-				Done:       r.Done,
-				DoneReason: r.DoneReason,
+				Message:    api.Message{Role: "assistant", Content: cr.Content},
+				Done:       cr.Done,
+				DoneReason: cr.DoneReason,
+				Logits:     []float32{},
 				Metrics: api.Metrics{
-					PromptEvalCount:    r.PromptEvalCount,
-					PromptEvalDuration: r.PromptEvalDuration,
-					EvalCount:          r.EvalCount,
-					EvalDuration:       r.EvalDuration,
+					PromptEvalCount:    cr.PromptEvalCount,
+					PromptEvalDuration: cr.PromptEvalDuration,
+					EvalCount:          cr.EvalCount,
+					EvalDuration:       cr.EvalDuration,
 				},
 			}

-			if r.Done {
+			topK := 3
+			logits := make([]float32, len(cr.Logits))
+			copy(logits, cr.Logits)
+			res.TopLogprobs = getTopKLogProbs(c.Request.Context(), r, logits, topK)
+
+			if cr.Done {
 				res.TotalDuration = time.Since(checkpointStart)
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
@@ -1582,7 +1591,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			// Streaming tool calls:
 			// If tools are recognized, use a flag to track the sending of a tool downstream
 			// This ensures that content is cleared from the message on the last chunk sent
-			sb.WriteString(r.Content)
+			sb.WriteString(cr.Content)
 			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
 				res.Message.ToolCalls = toolCalls
 				for i := range toolCalls {
@@ -1595,7 +1604,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				return
 			}

-			if r.Done {
+			if cr.Done {
 				// Send any remaining content if no tool calls were detected
 				if toolCallIndex == 0 {
 					res.Message.Content = sb.String()
@@ -1645,6 +1654,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }

+// getTopKLogProbs converts raw logits into log probabilities and returns
+// the topK highest-scoring tokens with their detokenized text.
+func getTopKLogProbs(ctx context.Context, s llm.LlamaServer, logits []float32, topK int) []api.TokenLogprob {
+	// Calculate the softmax denominator first (log-sum-exp trick for
+	// numerical stability)
+	maxLogit := float32(math.Inf(-1))
+	for _, logit := range logits {
+		if logit > maxLogit {
+			maxLogit = logit
+		}
+	}
+
+	var sumExp float32
+	for _, logit := range logits {
+		sumExp += float32(math.Exp(float64(logit - maxLogit)))
+	}
+	logSumExp := float32(math.Log(float64(sumExp))) + maxLogit
+
+	// Calculate log probs for every token id
+	logProbs := make([]api.TokenLogprob, len(logits))
+	for i, logit := range logits {
+		text, err := s.Detokenize(ctx, []int{i})
+		if err != nil {
+			slog.Error("detokenize error for logprob", "error", err)
+			continue
+		}
+		logProbs[i] = api.TokenLogprob{
+			Text:    text,
+			Logprob: logit - logSumExp,
+		}
+	}
+
+	// Sort by logprob descending and take the top K
+	sort.Slice(logProbs, func(i, j int) bool {
+		return logProbs[i].Logprob > logProbs[j].Logprob
+	})
+
+	if len(logProbs) > topK {
+		logProbs = logProbs[:topK]
+	}
+
+	return logProbs
+}
+
 func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
 	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
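One design note on getTopKLogProbs: it detokenizes every vocabulary id before sorting, one Detokenize round trip per token. Because softmax is monotonic, sorting ids by raw logit and detokenizing only the survivors would produce the same top K with at most topK round trips. A standalone sketch of that selection (an assumed optimization, not part of the diff above):

    package main

    import (
        "fmt"
        "sort"
    )

    // topKIndices returns the ids of the k highest logits. Softmax preserves
    // order, so ranking by raw logit matches ranking by logprob.
    func topKIndices(logits []float32, k int) []int {
        idx := make([]int, len(logits))
        for i := range idx {
            idx[i] = i
        }
        sort.Slice(idx, func(a, b int) bool {
            return logits[idx[a]] > logits[idx[b]]
        })
        if len(idx) > k {
            idx = idx[:k]
        }
        return idx // detokenize only these k ids afterwards
    }

    func main() {
        fmt.Println(topKIndices([]float32{0.1, 2.5, -1.0, 1.7}, 3)) // [1 3 0]
    }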