log probs working

parent f9928b677f
commit afa2e855d4
@@ -129,7 +129,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
     return nil
 }
 
-const maxBufferSize = 512 * format.KiloByte
+const maxBufferSize = 1024 * format.KiloByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
     var buf *bytes.Buffer
api/types.go
@@ -189,11 +189,12 @@ func (t *ToolFunction) String() string {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
     Model      string    `json:"model"`
     CreatedAt  time.Time `json:"created_at"`
     Message    Message   `json:"message"`
     DoneReason string    `json:"done_reason,omitempty"`
     Logits     []float32 `json:"logits"`
+    TopLogprobs []TokenLogprob `json:"top_logprobs"`
 
     Done bool `json:"done"`
 
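Note on the new field: below is a minimal consumer sketch, assuming the api package's existing Chat streaming callback; the model name and prompt are placeholders, and TopLogprobs is only populated when the server computes log probs (per this commit, the chat path does so on every chunk).

    package main

    import (
        "context"
        "fmt"
        "log"

        "github.com/ollama/ollama/api"
    )

    func main() {
        // Assumes a running server; "llama3" is a placeholder model name.
        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }
        req := &api.ChatRequest{
            Model:    "llama3",
            Messages: []api.Message{{Role: "user", Content: "hi"}},
        }
        err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
            // Print each chunk's top log probabilities.
            for _, lp := range resp.TopLogprobs {
                fmt.Printf("%q: %.4f\n", lp.Text, lp.Logprob)
            }
            return nil
        })
        if err != nil {
            log.Fatal(err)
        }
    }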
@@ -210,14 +211,10 @@ type Metrics struct {
 }
 
 type TokenLogprob struct {
-    Token   string  `json:"token"`
+    Text    string  `json:"text"`
     Logprob float32 `json:"logprob"`
 }
 
-type LogProbs struct {
-    TopLogprobs []TokenLogprob `json:"top_logprobs"`
-}
-
 // Options specified in [GenerateRequest]. If you add a new option here, also
 // add it to the API docs.
 type Options struct {
@@ -273,6 +273,18 @@ func (c *Context) GetLogits() []float32 {
     return unsafe.Slice((*float32)(logits), vocabSize)
 }
 
+func (m *Model) Detokenize(tokens []int) (string, error) {
+    var text string
+    for _, token := range tokens {
+        piece := m.TokenToPiece(token)
+        if piece == "" {
+            return "", fmt.Errorf("failed to convert token %d to piece", token)
+        }
+        text += piece
+    }
+    return text, nil
+}
+
 type ModelParams struct {
     NumGpuLayers int
     MainGpu      int
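A quick round-trip sketch of the new helper, assuming a loaded *llama.Model (here `model`) and the package's Tokenize method; the Tokenize signature is approximate, not confirmed by this diff:

    // Hypothetical round trip: text -> token IDs -> text again.
    tokens, err := model.Tokenize("hello world", false, true)
    if err != nil {
        log.Fatal(err)
    }
    text, err := model.Detokenize(tokens)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(text) // expected to reproduce the input, up to tokenizer normalization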
@@ -15,7 +15,6 @@ import (
     "path/filepath"
     "regexp"
     "runtime"
-    "sort"
     "strconv"
     "strings"
     "sync"
@@ -503,9 +502,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
         if seq.returnLogits { // New flag we need to add to Sequence struct
             logits := s.lc.GetLogits()
             seq.logits = make([]float32, len(logits))
-            slog.Info("copying logits")
             copy(seq.logits, logits)
-            slog.Info("copying logits success")
         }
 
         // Then sample token
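The copy here matters: GetLogits (see the earlier hunk) builds its result with unsafe.Slice over llama.cpp's internal buffer, so the returned slice aliases memory the next decode overwrites. A self-contained illustration of the hazard, with a plain slice standing in for the C buffer:

    package main

    import "fmt"

    func main() {
        backing := []float32{0.1, 0.2} // stands in for llama.cpp's logits buffer
        logits := backing[:]           // what GetLogits effectively returns: an alias
        snapshot := make([]float32, len(logits))
        copy(snapshot, logits) // seq.logits gets its own memory
        backing[0] = 9.9       // the next Decode would overwrite the buffer...
        fmt.Println(logits[0], snapshot[0]) // 9.9 0.1 — the alias changed, the copy didn't
    }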
@@ -608,7 +605,7 @@ type CompletionRequest struct {
     Images      []ImageData `json:"image_data"`
     Grammar     string      `json:"grammar"`
     CachePrompt bool        `json:"cache_prompt"`
-    ReturnLogits bool `json:"return_logits"`
+    ReturnLogits bool `json:"return_logits,omitempty"` // defaults to false
 
     Options
 }
@@ -729,7 +726,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
             return
         case content, ok := <-seq.responses:
             if ok {
-                slog.Info("content", "content", content.Content)
+                // slog.Info("content", "content", content.Content)
                 if err := json.NewEncoder(w).Encode(&content); err != nil {
                     http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
                     close(seq.quit)
@@ -1040,50 +1037,50 @@ func Execute(args []string) error {
     return nil
 }
 
-// Helper function to get top K logits and convert to log probabilities
-func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
-    if k <= 0 {
-        return nil
-    }
+// // Helper function to get top K logits and convert to log probabilities
+// func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
+//     if k <= 0 {
+//         return nil
+//     }
 
-    // Convert logits to probabilities using softmax
-    probs := softmax(logits)
+//     // Convert logits to probabilities using softmax
+//     probs := softmax(logits)
 
-    // Create slice of index/probability pairs
-    pairs := make([]struct {
-        token int
-        prob  float32
-    }, len(probs))
+//     // Create slice of index/probability pairs
+//     pairs := make([]struct {
+//         token int
+//         prob  float32
+//     }, len(probs))
 
-    for i, p := range probs {
-        pairs[i] = struct {
-            token int
-            prob  float32
-        }{i, p}
-    }
+//     for i, p := range probs {
+//         pairs[i] = struct {
+//             token int
+//             prob  float32
+//         }{i, p}
+//     }
 
-    // Sort by probability (descending)
-    sort.Slice(pairs, func(i, j int) bool {
-        return pairs[i].prob > pairs[j].prob
-    })
+//     // Sort by probability (descending)
+//     sort.Slice(pairs, func(i, j int) bool {
+//         return pairs[i].prob > pairs[j].prob
+//     })
 
-    // Take top K
-    k = min(k, len(pairs))
-    result := make([]api.LogProbs, k)
+//     // Take top K
+//     k = min(k, len(pairs))
+//     result := make([]api.LogProbs, k)
 
-    for i := 0; i < k; i++ {
-        result[i] = api.LogProbs{
-            TopLogprobs: []api.TokenLogprob{
-                {
-                    Token:   model.TokenToPiece(pairs[i].token),
-                    Logprob: float32(math.Log(float64(pairs[i].prob))),
-                },
-            },
-        }
-    }
+//     for i := 0; i < k; i++ {
+//         result[i] = api.LogProbs{
+//             TopLogprobs: []api.TokenLogprob{
+//                 {
+//                     Token:   model.TokenToPiece(pairs[i].token),
+//                     Logprob: float32(math.Log(float64(pairs[i].prob))),
+//                 },
+//             },
+//         }
+//     }
 
-    return result
-}
+//     return result
+// }
 
 // Helper function to compute softmax
 func softmax(logits []float32) []float32 {
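The commented-out helper computed math.Log over a softmax'd probability, which underflows once the exponentials saturate; the replacement added later in routes computes logit − logSumExp directly. A standalone sketch of the difference (not part of the patch):

    package main

    import (
        "fmt"
        "math"
    )

    func main() {
        logits := []float64{1000, 999} // large logits: exp overflows to +Inf
        // Naive: softmax, then log.
        var sum float64
        for _, l := range logits {
            sum += math.Exp(l) // +Inf
        }
        naive := math.Log(math.Exp(logits[1]) / sum) // Log(+Inf/+Inf) = NaN
        // Stable: shift by the max before exponentiating (log-sum-exp trick).
        maxL := math.Max(logits[0], logits[1])
        var sumExp float64
        for _, l := range logits {
            sumExp += math.Exp(l - maxL)
        }
        stable := logits[1] - (math.Log(sumExp) + maxL)
        fmt.Println(naive, stable) // NaN vs. ≈ -1.3133
    }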
@@ -19,6 +19,7 @@ import (
     "os/signal"
     "path/filepath"
     "slices"
+    "sort"
     "strings"
     "syscall"
     "time"
@@ -299,7 +300,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
         Images:  images,
         Format:  req.Format,
         Options: opts,
-        ReturnLogits: req.ReturnLogits,
+        ReturnLogits: false,
     }, func(cr llm.CompletionResponse) {
         res := api.GenerateResponse{
             Model: req.Model,
@@ -1554,23 +1555,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
         Format:  req.Format,
         Options: opts,
         ReturnLogits: true,
-    }, func(r llm.CompletionResponse) {
+    }, func(cr llm.CompletionResponse) {
         res := api.ChatResponse{
             Model:     req.Model,
             CreatedAt: time.Now().UTC(),
-            Message:    api.Message{Role: "assistant", Content: r.Content},
-            Done:       r.Done,
-            DoneReason: r.DoneReason,
-            Logits:     r.Logits,
+            Message:    api.Message{Role: "assistant", Content: cr.Content},
+            Done:       cr.Done,
+            DoneReason: cr.DoneReason,
+            Logits:     []float32{},
             Metrics: api.Metrics{
-                PromptEvalCount:    r.PromptEvalCount,
-                PromptEvalDuration: r.PromptEvalDuration,
-                EvalCount:          r.EvalCount,
-                EvalDuration:       r.EvalDuration,
+                PromptEvalCount:    cr.PromptEvalCount,
+                PromptEvalDuration: cr.PromptEvalDuration,
+                EvalCount:          cr.EvalCount,
+                EvalDuration:       cr.EvalDuration,
             },
         }
 
-        if r.Done {
+        topK := int(3)
+        logits := make([]float32, len(cr.Logits))
+        copy(logits, cr.Logits)
+        res.TopLogprobs = getTopKLogProbs(c.Request.Context(), r, logits, topK)
+        if cr.Done {
             res.TotalDuration = time.Since(checkpointStart)
             res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
         }
@@ -1586,7 +1591,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
         // Streaming tool calls:
         // If tools are recognized, use a flag to track the sending of a tool downstream
         // This ensures that content is cleared from the message on the last chunk sent
-        sb.WriteString(r.Content)
+        sb.WriteString(cr.Content)
         if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
             res.Message.ToolCalls = toolCalls
             for i := range toolCalls {
@@ -1599,7 +1604,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
             return
         }
 
-        if r.Done {
+        if cr.Done {
             // Send any remaining content if no tool calls were detected
             if toolCallIndex == 0 {
                 res.Message.Content = sb.String()
@@ -1649,6 +1654,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
     streamResponse(c, ch)
 }
 
+func getTopKLogProbs(ctx context.Context, s llm.LlamaServer, logits []float32, topK int) []api.TokenLogprob {
+    // Calculate softmax denominator first (log sum exp trick for numerical stability)
+    maxLogit := float32(math.Inf(-1))
+    for _, logit := range logits {
+        if logit > maxLogit {
+            maxLogit = logit
+        }
+    }
+
+    var sumExp float32
+    for _, logit := range logits {
+        sumExp += float32(math.Exp(float64(logit - maxLogit)))
+    }
+    logSumExp := float32(math.Log(float64(sumExp))) + maxLogit
+
+    // Calculate log probs and track top K
+    logProbs := make([]api.TokenLogprob, len(logits))
+    for i, logit := range logits {
+        text, err := s.Detokenize(ctx, []int{i})
+        if err != nil {
+            slog.Error("detokenize error for logprob", "error", err)
+            continue
+        }
+
+        logProbs[i] = api.TokenLogprob{
+            Text:    text,
+            Logprob: logit - logSumExp,
+        }
+    }
+
+    // Sort by logprob descending and take top K
+    sort.Slice(logProbs, func(i, j int) bool {
+        return logProbs[i].Logprob > logProbs[j].Logprob
+    })
+
+    if len(logProbs) > topK {
+        logProbs = logProbs[:topK]
+    }
+
+    return logProbs
+}
+
 func handleScheduleError(c *gin.Context, name string, err error) {
     switch {
     case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
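Sanity check on the math in getTopKLogProbs (standalone, not part of the patch): for logits {1, 2, 3}, logSumExp ≈ 3.4076, so the log probs come out to about −2.41, −1.41, −0.41, which exponentiate back to probabilities summing to 1.

    package main

    import (
        "fmt"
        "math"
    )

    func main() {
        logits := []float64{1, 2, 3}
        maxL := 3.0
        var sumExp float64
        for _, l := range logits {
            sumExp += math.Exp(l - maxL) // e^-2 + e^-1 + e^0 ≈ 1.5032
        }
        logSumExp := math.Log(sumExp) + maxL // ≈ 3.4076
        for _, l := range logits {
            fmt.Printf("%.4f\n", l-logSumExp) // ≈ -2.4076, -1.4076, -0.4076
        }
    }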