diff --git a/api/types.go b/api/types.go
index 0ea0b9bf0..c9497a40e 100644
--- a/api/types.go
+++ b/api/types.go
@@ -80,6 +80,8 @@ type GenerateRequest struct {
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
+
+	ReturnLogits bool `json:"return_logits,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +107,8 @@ type ChatRequest struct {
 
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
+
+	ReturnLogits bool `json:"return_logits,omitempty"`
 }
 
 type Tools []Tool
@@ -189,6 +193,7 @@ type ChatResponse struct {
 	CreatedAt  time.Time `json:"created_at"`
 	Message    Message   `json:"message"`
 	DoneReason string    `json:"done_reason,omitempty"`
+	Logits     []float32 `json:"logits"`
 
 	Done bool `json:"done"`
 
@@ -204,6 +209,15 @@ type Metrics struct {
 	EvalDuration    time.Duration `json:"eval_duration,omitempty"`
 }
 
+type TokenLogprob struct {
+	Token   string  `json:"token"`
+	Logprob float32 `json:"logprob"`
+}
+
+type LogProbs struct {
+	TopLogprobs []TokenLogprob `json:"top_logprobs"`
+}
+
 // Options specified in [GenerateRequest]. If you add a new option here, also
 // add it to the API docs.
 type Options struct {
@@ -450,6 +464,8 @@ type GenerateResponse struct {
 	Context []int `json:"context,omitempty"`
 
 	Metrics
+
+	Logits []float32 `json:"logits"`
 }
 
 // ModelDetails provides details about a model.
diff --git a/llama/llama.go b/llama/llama.go
index acbf0a672..25bc67086 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -260,6 +260,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }
 
+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine array size
+	vocabSize := c.Model().NumVocab()
+	return unsafe.Slice((*float32)(logits), vocabSize)
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
@@ -737,14 +750,3 @@ func SchemaToGrammar(schema []byte) []byte {
 	}
 	return buf[:n]
 }
-
-// GetLogits returns the logits from the last decode operation.
-// The returned slice has length equal to the vocabulary size.
-func (c *Context) GetLogits() []float32 {
-	logits := unsafe.Pointer(C.llama_get_logits(c.c))
-	if logits == nil {
-		return nil
-	}
-
-	// Get the number of vocabulary tokens to determine array size
-	vocabSize := c.Model().NumVocab()
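Note: the TokenLogprob and LogProbs types above are declared but nothing in this patch populates them yet. For reference, a minimal self-contained sketch of turning raw logits, such as the slice Context.GetLogits returns, into log-probabilities with a numerically stable log-sum-exp; softmaxLogprobs is a made-up name, not part of the change:

	package main

	import (
		"fmt"
		"math"
	)

	// softmaxLogprobs converts raw logits into log-probabilities:
	// logprob[i] = logits[i] - (max + log(sum(exp(logits[j] - max)))).
	// Subtracting the max first keeps math.Exp from overflowing.
	func softmaxLogprobs(logits []float32) []float32 {
		maxLogit := float32(math.Inf(-1))
		for _, l := range logits {
			if l > maxLogit {
				maxLogit = l
			}
		}
		var sumExp float64
		for _, l := range logits {
			sumExp += math.Exp(float64(l - maxLogit))
		}
		logSumExp := maxLogit + float32(math.Log(sumExp))
		out := make([]float32, len(logits))
		for i, l := range logits {
			out[i] = l - logSumExp
		}
		return out
	}

	func main() {
		fmt.Println(softmaxLogprobs([]float32{1, 2, 3})) // approx [-2.4076 -1.4076 -0.4076]
	}

Working in log space avoids underflow for the large negative logits typical of a vocabulary tens of thousands of entries wide.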
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index e7ed3ba6d..1efb2ba98 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -59,7 +59,7 @@ type Sequence struct {
 	crossAttention bool
 
 	// channel to send responses over
-	responses chan string
+	responses chan CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -88,6 +88,15 @@ type Sequence struct {
 	startGenerationTime time.Time
 	numDecoded          int
 	numPromptInputs     int
+
+	// whether to capture logits for this sequence
+	returnLogits bool
+
+	// logits from the most recent decode, captured when returnLogits is set
+	logits []float32
+
+	// channel to send logits over
+	logitsOut chan []float32
 }
 
 type NewSequenceParams struct {
@@ -96,6 +105,7 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	returnLogits   bool
 }
 
 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -149,13 +159,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
 		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		responses:           make(chan CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
 		embeddingOnly:       params.embedding,
 		stop:                params.stop,
 		numKeep:             params.numKeep,
+		returnLogits:        params.returnLogits,
+		logitsOut:           make(chan []float32, 100),
 	}, nil
 }
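Note: given the struct tags above, a streamed chunk from the runner's completion endpoint now serializes roughly as below. Values are illustrative; logits carries one entry per vocabulary token (per GetLogits), and the tokens field is declared but never filled by this patch:

	{"content":"Hello","logits":[-1.25,0.42,3.71],"stop":false}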
@@ -274,25 +286,34 @@ func (s *Server) allNil() bool {
 }
 
 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
+	if len(seq.pendingResponses) == 0 {
+		return true
+	}
+	content := strings.Join(seq.pendingResponses, "")
 
 	// Check if there are any partial UTF-8 characters remaining.
 	// We already check and queue as we are generating but some may
 	// still make it here:
 	// - Sequence is ending, e.g. generation limit has been hit
 	// - Invalid characters in the middle of a string
 	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
+	for !utf8.ValidString(content) {
+		content = content[:len(content)-1]
+	}
+	seq.pendingResponses = nil
+
+	resp := CompletionResponse{
+		Content: content,
 	}
-	if len(joined) == 0 {
-		return true
+	// Attach logits if requested and captured for this token
+	if seq.returnLogits && seq.logits != nil {
+		resp.Logits = seq.logits
+		seq.logits = nil
 	}
 
 	select {
-	case seq.responses <- joined:
+	case seq.responses <- resp:
 		return true
 	case <-seq.quit:
 		return false
@@ -476,7 +497,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
-		// sample a token
+		// capture the logits from this decode before sampling, if requested
+		if seq.returnLogits {
+			seq.logits = s.lc.GetLogits()
+		}
+
+		// sample a token
 		token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
 		seq.samplingCtx.Accept(token, true)
 		piece := s.model.TokenToPiece(token)
@@ -572,10 +598,11 @@ type ImageData struct {
 }
 
 type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
+	Prompt       string      `json:"prompt"`
+	Images       []ImageData `json:"image_data"`
+	Grammar      string      `json:"grammar"`
+	CachePrompt  bool        `json:"cache_prompt"`
+	ReturnLogits bool        `json:"return_logits"`
 
 	Options
 }
@@ -588,8 +615,10 @@ type Timings struct {
 }
 
 type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
+	Content string    `json:"content"`
+	Logits  []float32 `json:"logits,omitempty"`
+	Tokens  []string  `json:"tokens,omitempty"`
+	Stop    bool      `json:"stop"`
 
 	Model  string `json:"model,omitempty"`
 	Prompt string `json:"prompt,omitempty"`
@@ -637,12 +666,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
+	slog.Debug("completion request", "return_logits", req.ReturnLogits)
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict:     req.NumPredict,
 		stop:           req.Stop,
 		numKeep:        req.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		returnLogits:   req.ReturnLogits,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -691,10 +722,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			close(seq.quit)
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
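Note: llm/server.go below forwards the raw logits verbatim. Populating the still-unused LogProbs.TopLogprobs type from api/types.go would additionally need a top-k pass over the vocabulary; a self-contained sketch (topKIndices is illustrative, not part of the change):

	package main

	import (
		"fmt"
		"sort"
	)

	// topKIndices returns the indices of the k largest logits, descending.
	// Sorting an index slice rather than the logits themselves preserves
	// the original positions, which is what maps back to token ids.
	func topKIndices(logits []float32, k int) []int {
		idx := make([]int, len(logits))
		for i := range idx {
			idx[i] = i
		}
		sort.Slice(idx, func(a, b int) bool { return logits[idx[a]] > logits[idx[b]] })
		if k > len(idx) {
			k = len(idx)
		}
		return idx[:k]
	}

	func main() {
		fmt.Println(topKIndices([]float32{0.1, 2.5, -1.0, 0.7}, 2)) // [1 3]
	}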
`json:"stopped_limit"` + Logits []float32 `json:"logits,omitempty"` Timings struct { PredictedN int `json:"predicted_n"` @@ -657,10 +658,11 @@ type completion struct { } type CompletionRequest struct { - Prompt string - Format json.RawMessage - Images []ImageData - Options *api.Options + Prompt string + Format json.RawMessage + Images []ImageData + Options *api.Options + ReturnLogits bool } type CompletionResponse struct { @@ -671,6 +673,7 @@ type CompletionResponse struct { PromptEvalDuration time.Duration EvalCount int EvalDuration time.Duration + Logits []float32 } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { @@ -696,6 +699,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu "seed": req.Options.Seed, "stop": req.Options.Stop, "image_data": req.Images, + "return_logits": req.ReturnLogits, "cache_prompt": true, } @@ -821,6 +825,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if c.Content != "" { fn(CompletionResponse{ Content: c.Content, + Logits: c.Logits, }) } @@ -837,6 +842,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu PromptEvalDuration: parseDurationMs(c.Timings.PromptMS), EvalCount: c.Timings.PredictedN, EvalDuration: parseDurationMs(c.Timings.PredictedMS), + Logits: c.Logits, }) return nil } diff --git a/server/routes.go b/server/routes.go index f3b78927c..8e5d9d8bd 100644 --- a/server/routes.go +++ b/server/routes.go @@ -295,10 +295,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { var sb strings.Builder defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, + ReturnLogits: req.ReturnLogits, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, @@ -312,6 +313,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { EvalCount: cr.EvalCount, EvalDuration: cr.EvalDuration, }, + Logits: cr.Logits, } if _, err := sb.WriteString(cr.Content); err != nil { @@ -1541,16 +1543,19 @@ func (s *Server) ChatHandler(c *gin.Context) { slog.Debug("chat request", "images", len(images), "prompt", prompt) + slog.Info("chat request", "return_logits", req.ReturnLogits) + ch := make(chan any) go func() { defer close(ch) var sb strings.Builder var toolCallIndex int = 0 if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, + ReturnLogits: true, }, func(r llm.CompletionResponse) { res := api.ChatResponse{ Model: req.Model, @@ -1558,6 +1563,7 @@ func (s *Server) ChatHandler(c *gin.Context) { Message: api.Message{Role: "assistant", Content: r.Content}, Done: r.Done, DoneReason: r.DoneReason, + Logits: r.Logits, Metrics: api.Metrics{ PromptEvalCount: r.PromptEvalCount, PromptEvalDuration: r.PromptEvalDuration,