Compare commits: main...parth/log-

4 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | afa2e855d4 |  |
|  | f9928b677f |  |
|  | c92d418a7c |  |
|  | d7e7e6a01e |  |
```diff
@@ -129,7 +129,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	return nil
 }
 
-const maxBufferSize = 512 * format.KiloByte
+const maxBufferSize = 1024 * format.KiloByte
 
 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer
```
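A note on why this constant doubles: the client-side scanner reads one streamed JSON object per line, and once a line can carry a per-token logit vector it can blow past 512 KB. Rough arithmetic below; the ~12 bytes per JSON-rendered float32 and the vocabulary sizes are assumptions for illustration, not values from this change.

```go
package main

import "fmt"

// Back-of-envelope for the buffer bump. Assumptions (not from this PR):
// a float32 rendered into JSON averages ~12 bytes including the comma,
// and vocabularies of 32K and 128K tokens.
func main() {
	const bytesPerFloat = 12
	fmt.Printf("32K vocab:  ~%d KB per logits payload\n", 32000*bytesPerFloat/1024)  // ~375 KB
	fmt.Printf("128K vocab: ~%d KB per logits payload\n", 128256*bytesPerFloat/1024) // ~1503 KB
}
```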
api/types.go (21 changes)
```diff
@@ -80,6 +80,8 @@ type GenerateRequest struct {
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
+
+	ReturnLogits bool `json:"return_logits,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +107,8 @@ type ChatRequest struct {
 
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
+
+	ReturnLogits bool `json:"return_logits,omitempty"`
 }
 
 type Tools []Tool
@@ -185,10 +189,12 @@ func (t *ToolFunction) String() string {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-	Model      string    `json:"model"`
-	CreatedAt  time.Time `json:"created_at"`
-	Message    Message   `json:"message"`
-	DoneReason string    `json:"done_reason,omitempty"`
+	Model       string         `json:"model"`
+	CreatedAt   time.Time      `json:"created_at"`
+	Message     Message        `json:"message"`
+	DoneReason  string         `json:"done_reason,omitempty"`
+	Logits      []float32      `json:"logits"`
+	TopLogprobs []TokenLogprob `json:"top_logprobs"`
 
 	Done bool `json:"done"`
 
@@ -204,6 +210,11 @@ type Metrics struct {
 	EvalDuration time.Duration `json:"eval_duration,omitempty"`
 }
 
+type TokenLogprob struct {
+	Text    string  `json:"text"`
+	Logprob float32 `json:"logprob"`
+}
+
 // Options specified in [GenerateRequest]. If you add a new option here, also
 // add it to the API docs.
 type Options struct {
@@ -450,6 +461,8 @@ type GenerateResponse struct {
 	Context []int `json:"context,omitempty"`
 
 	Metrics
+
+	Logits []float32 `json:"logits"`
 }
 
 // ModelDetails provides details about a model.
```
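To make the new fields concrete, here is a sketch of a caller posting to the HTTP API with `return_logits` set and reading the logits back. The endpoint, model name, and non-streaming response shape follow the existing `/api/generate` conventions; the local mirror structs exist only for this sketch.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Minimal mirrors of the fields added in api/types.go above;
// only what this sketch needs.
type generateRequest struct {
	Model        string `json:"model"`
	Prompt       string `json:"prompt"`
	Stream       bool   `json:"stream"`
	ReturnLogits bool   `json:"return_logits,omitempty"`
}

type generateResponse struct {
	Response string    `json:"response"`
	Logits   []float32 `json:"logits"`
}

func main() {
	body, _ := json.Marshal(generateRequest{
		Model:        "llama3.2", // assumed model name for illustration
		Prompt:       "Why is the sky blue?",
		Stream:       false,
		ReturnLogits: true,
	})
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out generateResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Printf("%q, %d logits\n", out.Response, len(out.Logits))
}
```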
```diff
@@ -260,6 +260,31 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }
 
+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine array size
+	vocabSize := c.Model().NumVocab()
+	return unsafe.Slice((*float32)(logits), vocabSize)
+}
+
+func (m *Model) Detokenize(tokens []int) (string, error) {
+	var text string
+	for _, token := range tokens {
+		piece := m.TokenToPiece(token)
+		if piece == "" {
+			return "", fmt.Errorf("failed to convert token %d to piece", token)
+		}
+		text += piece
+	}
+	return text, nil
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
```
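One caveat for callers of the new accessor: `unsafe.Slice` wraps llama.cpp's internal logits buffer rather than copying it, so the returned slice is only valid until the next decode. The runner change below copies before storing; the same pattern as a standalone helper (illustrative, not part of this diff):

```go
// snapshotLogits copies a logits view such as the one returned by
// (*Context).GetLogits above. The backing memory belongs to llama.cpp
// and is reused on the next decode, so it must be copied before it
// escapes the current batch.
func snapshotLogits(view []float32) []float32 {
	if view == nil {
		return nil
	}
	out := make([]float32, len(view))
	copy(out, view)
	return out
}
```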
```diff
@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
@@ -59,7 +60,7 @@ type Sequence struct {
 	crossAttention bool
 
 	// channel to send responses over
-	responses chan string
+	responses chan CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -88,6 +89,15 @@ type Sequence struct {
 	startGenerationTime time.Time
 	numDecoded          int
 	numPromptInputs     int
+
+	// whether to return logits with the generated tokens
+	returnLogits bool
+
+	// logits captured from the most recent decode, via Context.GetLogits
+	logits []float32
+
+	// channel for sending logits out alongside responses
+	logitsOut chan []float32
 }
 
 type NewSequenceParams struct {
@@ -96,6 +106,7 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	returnLogits   bool
 }
 
 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -149,13 +160,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
 		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		responses:           make(chan CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
 		embeddingOnly:       params.embedding,
 		stop:                params.stop,
 		numKeep:             params.numKeep,
+		returnLogits:        params.returnLogits,
+		logitsOut:           make(chan []float32, 100),
 	}, nil
 }
 
@@ -274,25 +287,34 @@ func (s *Server) allNil() bool {
 }
 
 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
+	if len(seq.pendingResponses) == 0 {
+		return true
+	}
 
+	content := strings.Join(seq.pendingResponses, "")
 	// Check if there are any partial UTF-8 characters remaining.
 	// We already check and queue as we are generating but some may
 	// still make it here:
 	// - Sequence is ending, e.g. generation limit has been hit
 	// - Invalid characters in the middle of a string
 	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
+	for !utf8.ValidString(content) {
+		content = content[:len(content)-1]
 	}
+	seq.pendingResponses = nil
 
-	if len(joined) == 0 {
-		return true
+	resp := CompletionResponse{
+		Content: content,
+	}
+
+	// Add logits if requested and available
+	if seq.returnLogits && seq.logits != nil {
+		resp.Logits = seq.logits
+		seq.logits = nil
 	}
 
 	select {
-	case seq.responses <- joined:
+	case seq.responses <- resp:
 		return true
 	case <-seq.quit:
 		return false
@@ -476,7 +498,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
+		// Capture the raw logits for this position before sampling.
+		// GetLogits returns a view into llama.cpp's buffer, so copy it.
+		if seq.returnLogits {
+			logits := s.lc.GetLogits()
+			seq.logits = make([]float32, len(logits))
+			copy(seq.logits, logits)
+		}
+
 		// sample a token
 		token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
 		seq.samplingCtx.Accept(token, true)
 		piece := s.model.TokenToPiece(token)
@@ -572,10 +601,11 @@ type ImageData struct {
 }
 
 type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
+	Prompt       string      `json:"prompt"`
+	Images       []ImageData `json:"image_data"`
+	Grammar      string      `json:"grammar"`
+	CachePrompt  bool        `json:"cache_prompt"`
+	ReturnLogits bool        `json:"return_logits,omitempty"` // defaults to false
 
 	Options
 }
@@ -588,8 +618,10 @@ type Timings struct {
 }
 
 type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
+	Content string    `json:"content"`
+	Logits  []float32 `json:"logits,omitempty"`
+	Tokens  []string  `json:"tokens,omitempty"`
+	Stop    bool      `json:"stop"`
 
 	Model  string `json:"model,omitempty"`
 	Prompt string `json:"prompt,omitempty"`
@@ -637,12 +669,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
+	slog.Info("completion request", "return_logits", req.ReturnLogits)
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict:     req.NumPredict,
 		stop:           req.Stop,
 		numKeep:        req.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		returnLogits:   req.ReturnLogits,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -692,9 +726,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				return
 			case content, ok := <-seq.responses:
 				if ok {
-					if err := json.NewEncoder(w).Encode(&CompletionResponse{
-						Content: content,
-					}); err != nil {
+					// content is now a full CompletionResponse; encode it directly
+					if err := json.NewEncoder(w).Encode(&content); err != nil {
 						http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 						close(seq.quit)
 						return
@@ -1003,3 +1036,76 @@ func Execute(args []string) error {
 	cancel()
 	return nil
 }
+
+// // Helper function to get top K logits and convert to log probabilities
+// func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
+// 	if k <= 0 {
+// 		return nil
+// 	}
+
+// 	// Convert logits to probabilities using softmax
+// 	probs := softmax(logits)
+
+// 	// Create slice of index/probability pairs
+// 	pairs := make([]struct {
+// 		token int
+// 		prob  float32
+// 	}, len(probs))
+
+// 	for i, p := range probs {
+// 		pairs[i] = struct {
+// 			token int
+// 			prob  float32
+// 		}{i, p}
+// 	}
+
+// 	// Sort by probability (descending)
+// 	sort.Slice(pairs, func(i, j int) bool {
+// 		return pairs[i].prob > pairs[j].prob
+// 	})
+
+// 	// Take top K
+// 	k = min(k, len(pairs))
+// 	result := make([]api.LogProbs, k)
+
+// 	for i := 0; i < k; i++ {
+// 		result[i] = api.LogProbs{
+// 			TopLogprobs: []api.TokenLogprob{
+// 				{
+// 					Token:   model.TokenToPiece(pairs[i].token),
+// 					Logprob: float32(math.Log(float64(pairs[i].prob))),
+// 				},
+// 			},
+// 		}
+// 	}
+
+// 	return result
+// }
+
+// softmax converts logits to probabilities, subtracting the max first
+// for numerical stability.
+func softmax(logits []float32) []float32 {
+	probs := make([]float32, len(logits))
+
+	// Find max for numerical stability
+	max := float32(math.Inf(-1))
+	for _, l := range logits {
+		if l > max {
+			max = l
+		}
+	}
+
+	// Compute exp(x - max) and sum
+	sum := float32(0)
+	for i, l := range logits {
+		ex := float32(math.Exp(float64(l - max)))
+		probs[i] = ex
+		sum += ex
+	}
+
+	// Normalize
+	for i := range probs {
+		probs[i] /= sum
+	}
+
+	return probs
+}
```
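As a sanity check on the math, `log(softmax(x)[i])` should equal `x[i] - logSumExp(x)`, which is the form used by `getTopKLogProbs` further down. A self-contained snippet that verifies the two routes agree (the softmax body is duplicated from above so this runs on its own):

```go
package main

import (
	"fmt"
	"math"
)

// Same algorithm as the runner's softmax helper above.
func softmax(logits []float32) []float32 {
	probs := make([]float32, len(logits))
	max := float32(math.Inf(-1))
	for _, l := range logits {
		if l > max {
			max = l
		}
	}
	sum := float32(0)
	for i, l := range logits {
		ex := float32(math.Exp(float64(l - max)))
		probs[i] = ex
		sum += ex
	}
	for i := range probs {
		probs[i] /= sum
	}
	return probs
}

func main() {
	logits := []float32{2.0, 1.0, 0.1}

	// Route 1: log of the softmax probability.
	p := softmax(logits)

	// Route 2: logit - logSumExp, as in getTopKLogProbs below.
	maxLogit := float32(math.Inf(-1))
	for _, l := range logits {
		if l > maxLogit {
			maxLogit = l
		}
	}
	var sumExp float32
	for _, l := range logits {
		sumExp += float32(math.Exp(float64(l - maxLogit)))
	}
	logSumExp := float32(math.Log(float64(sumExp))) + maxLogit

	for i, l := range logits {
		fmt.Printf("token %d: log(softmax)=%.4f  logit-logSumExp=%.4f\n",
			i, math.Log(float64(p[i])), l-logSumExp)
	}
}
```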
```diff
@@ -633,7 +633,8 @@ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
 ws ::= ([ \t\n] ws)?
 `
 
-const maxBufferSize = 512 * format.KiloByte
+// TODO: change back to 512 * format.KiloByte
+const maxBufferSize = 2048 * format.KiloByte
 
 type ImageData struct {
 	Data []byte `json:"data"`
@@ -642,11 +643,12 @@ type ImageData struct {
 }
 
 type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
+	Content      string    `json:"content"`
+	Model        string    `json:"model"`
+	Prompt       string    `json:"prompt"`
+	Stop         bool      `json:"stop"`
+	StoppedLimit bool      `json:"stopped_limit"`
+	Logits       []float32 `json:"logits,omitempty"`
 
 	Timings struct {
 		PredictedN int `json:"predicted_n"`
@@ -657,10 +659,11 @@ type completion struct {
 }
 
 type CompletionRequest struct {
-	Prompt  string
-	Format  json.RawMessage
-	Images  []ImageData
-	Options *api.Options
+	Prompt       string
+	Format       json.RawMessage
+	Images       []ImageData
+	Options      *api.Options
+	ReturnLogits bool
 }
 
 type CompletionResponse struct {
@@ -671,6 +674,7 @@ type CompletionResponse struct {
 	PromptEvalDuration time.Duration
 	EvalCount          int
 	EvalDuration       time.Duration
+	Logits             []float32
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -696,6 +700,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		"seed":          req.Options.Seed,
 		"stop":          req.Options.Stop,
 		"image_data":    req.Images,
+		"return_logits": req.ReturnLogits,
 		"cache_prompt":  true,
 	}
 
@@ -821,6 +826,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if c.Content != "" {
 				fn(CompletionResponse{
 					Content: c.Content,
+					Logits:  c.Logits,
 				})
 			}
 
@@ -837,6 +843,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 				EvalCount:          c.Timings.PredictedN,
 				EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				Logits:             c.Logits,
 			})
 			return nil
 		}
```
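For orientation, a sketch of how a caller of this package might consume the new `Logits` field. `collectLogits` is hypothetical, the import paths assume the upstream ollama module layout, and `srv` stands in for an already-loaded `llm.LlamaServer`.

```go
import (
	"context"

	"github.com/ollama/ollama/api" // assumed module path
	"github.com/ollama/ollama/llm" // assumed module path
)

// collectLogits is a hypothetical helper: it streams a completion and
// snapshots the logits delivered with each chunk.
func collectLogits(ctx context.Context, srv llm.LlamaServer, prompt string, opts *api.Options) ([][]float32, error) {
	var all [][]float32
	err := srv.Completion(ctx, llm.CompletionRequest{
		Prompt:       prompt,
		Options:      opts,
		ReturnLogits: true,
	}, func(cr llm.CompletionResponse) {
		if len(cr.Logits) > 0 {
			// Copy, in case the callback's slice is reused upstream.
			chunk := make([]float32, len(cr.Logits))
			copy(chunk, cr.Logits)
			all = append(all, chunk)
		}
	})
	return all, err
}
```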
```diff
@@ -19,6 +19,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"sort"
 	"strings"
 	"syscall"
 	"time"
@@ -295,10 +296,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:       prompt,
+			Images:       images,
+			Format:       req.Format,
+			Options:      opts,
+			ReturnLogits: false,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model: req.Model,
@@ -312,6 +314,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				EvalCount:    cr.EvalCount,
 				EvalDuration: cr.EvalDuration,
 			},
+			Logits: cr.Logits,
 		}
 
 		if _, err := sb.WriteString(cr.Content); err != nil {
@@ -1547,26 +1550,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		var sb strings.Builder
 		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
-		}, func(r llm.CompletionResponse) {
+			Prompt:       prompt,
+			Images:       images,
+			Format:       req.Format,
+			Options:      opts,
+			ReturnLogits: true,
+		}, func(cr llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
-				Message:    api.Message{Role: "assistant", Content: r.Content},
-				Done:       r.Done,
-				DoneReason: r.DoneReason,
+				Message:    api.Message{Role: "assistant", Content: cr.Content},
+				Done:       cr.Done,
+				DoneReason: cr.DoneReason,
+				Logits:     []float32{},
 				Metrics: api.Metrics{
-					PromptEvalCount:    r.PromptEvalCount,
-					PromptEvalDuration: r.PromptEvalDuration,
-					EvalCount:          r.EvalCount,
-					EvalDuration:       r.EvalDuration,
+					PromptEvalCount:    cr.PromptEvalCount,
+					PromptEvalDuration: cr.PromptEvalDuration,
+					EvalCount:          cr.EvalCount,
+					EvalDuration:       cr.EvalDuration,
 				},
 			}
 
-			if r.Done {
+			topK := int(3)
+			logits := make([]float32, len(cr.Logits))
+			copy(logits, cr.Logits)
+			res.TopLogprobs = getTopKLogProbs(c.Request.Context(), r, logits, topK)
+
+			if cr.Done {
 				res.TotalDuration = time.Since(checkpointStart)
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
@@ -1582,7 +1591,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			// Streaming tool calls:
 			// If tools are recognized, use a flag to track the sending of a tool downstream
 			// This ensures that content is cleared from the message on the last chunk sent
-			sb.WriteString(r.Content)
+			sb.WriteString(cr.Content)
 			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
 				res.Message.ToolCalls = toolCalls
 				for i := range toolCalls {
@@ -1595,7 +1604,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				return
 			}
 
-			if r.Done {
+			if cr.Done {
 				// Send any remaining content if no tool calls were detected
 				if toolCallIndex == 0 {
 					res.Message.Content = sb.String()
@@ -1645,6 +1654,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
+func getTopKLogProbs(ctx context.Context, s llm.LlamaServer, logits []float32, topK int) []api.TokenLogprob {
+	// Calculate softmax denominator first (log-sum-exp trick for numerical stability)
+	maxLogit := float32(math.Inf(-1))
+	for _, logit := range logits {
+		if logit > maxLogit {
+			maxLogit = logit
+		}
+	}
+
+	var sumExp float32
+	for _, logit := range logits {
+		sumExp += float32(math.Exp(float64(logit - maxLogit)))
+	}
+	logSumExp := float32(math.Log(float64(sumExp))) + maxLogit
+
+	// Calculate log probs and track top K
+	logProbs := make([]api.TokenLogprob, len(logits))
+	for i, logit := range logits {
+		text, err := s.Detokenize(ctx, []int{i})
+		if err != nil {
+			slog.Error("detokenize error for logprob", "error", err)
+			continue
+		}
+
+		logProbs[i] = api.TokenLogprob{
+			Text:    text,
+			Logprob: logit - logSumExp,
+		}
+	}
+
+	// Sort by logprob descending and take top K
+	sort.Slice(logProbs, func(i, j int) bool {
+		return logProbs[i].Logprob > logProbs[j].Logprob
+	})
+
+	if len(logProbs) > topK {
+		logProbs = logProbs[:topK]
+	}
+
+	return logProbs
+}
+
 func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
 	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
```
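A possible follow-up on `getTopKLogProbs`: as written it detokenizes every vocabulary entry (one `Detokenize` call per id) and sorts the full vector, O(V log V), to keep three entries. A single pass with `container/heap` keeps the K best in O(V log K), after which only those K ids need detokenizing and the `logit - logSumExp` conversion. A sketch of that alternative (not part of this change):

```go
package main

import (
	"container/heap"
	"fmt"
)

// tokenLogit pairs a vocabulary index with its raw logit.
type tokenLogit struct {
	id    int
	logit float32
}

// minHeap keeps the smallest retained logit at the root so it can be
// evicted cheaply when a larger one arrives.
type minHeap []tokenLogit

func (h minHeap) Len() int           { return len(h) }
func (h minHeap) Less(i, j int) bool { return h[i].logit < h[j].logit }
func (h minHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *minHeap) Push(x any)        { *h = append(*h, x.(tokenLogit)) }
func (h *minHeap) Pop() any {
	old := *h
	x := old[len(old)-1]
	*h = old[:len(old)-1]
	return x
}

// topKLogits scans the logits once, keeping the k largest: O(V log K).
// Detokenization and the log-prob conversion would then run on just
// these k entries instead of the whole vocabulary.
func topKLogits(logits []float32, k int) []tokenLogit {
	h := make(minHeap, 0, k)
	for i, l := range logits {
		switch {
		case h.Len() < k:
			heap.Push(&h, tokenLogit{i, l})
		case l > h[0].logit:
			h[0] = tokenLogit{i, l}
			heap.Fix(&h, 0)
		}
	}
	return h
}

func main() {
	fmt.Println(topKLogits([]float32{0.1, 2.5, -1.0, 3.0, 0.7}, 3))
	// [{4 0.7} {1 2.5} {3 3}] (heap order, not sorted)
}
```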