WIP: return logits from generate and chat responses

This commit is contained in:
ParthSareen 2024-12-13 16:20:44 -08:00
parent d7e7e6a01e
commit c92d418a7c
5 changed files with 114 additions and 47 deletions

View File

@ -80,6 +80,8 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
ReturnLogits bool `json:"return_logits,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@ -105,6 +107,8 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
ReturnLogits bool `json:"return_logits,omitempty"`
}
type Tools []Tool
@ -189,6 +193,7 @@ type ChatResponse struct {
CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
Logits []float32 `json:"logits"`
Done bool `json:"done"`
@ -204,6 +209,15 @@ type Metrics struct {
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
// TokenLogprob pairs a token's text with its log-probability.
type TokenLogprob struct {
Token string `json:"token"`
Logprob float32 `json:"logprob"`
}
// LogProbs holds the most likely tokens and their log-probabilities
// for a sampled position.
type LogProbs struct {
TopLogprobs []TokenLogprob `json:"top_logprobs"`
}
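These two types are not yet populated anywhere in this commit. For illustration, here is a minimal same-package sketch of how LogProbs might be filled from a raw logits vector using a numerically stable log-softmax; the topLogprobs helper and the tokenText callback are hypothetical, and "math" and "sort" are assumed imported:

// Hypothetical helper (not part of this commit): derive the k most likely
// tokens and their log-probabilities from a raw logits vector.
func topLogprobs(logits []float32, k int, tokenText func(int) string) LogProbs {
    if len(logits) == 0 {
        return LogProbs{}
    }
    // log-softmax with max subtraction for numerical stability
    maxLogit := float32(math.Inf(-1))
    for _, l := range logits {
        if l > maxLogit {
            maxLogit = l
        }
    }
    var sum float64
    for _, l := range logits {
        sum += math.Exp(float64(l - maxLogit))
    }
    logSum := maxLogit + float32(math.Log(sum))
    // sort token IDs by descending logit
    ids := make([]int, len(logits))
    for i := range ids {
        ids[i] = i
    }
    sort.Slice(ids, func(a, b int) bool { return logits[ids[a]] > logits[ids[b]] })
    if k > len(ids) {
        k = len(ids)
    }
    var out LogProbs
    for _, id := range ids[:k] {
        out.TopLogprobs = append(out.TopLogprobs, TokenLogprob{
            Token:   tokenText(id),
            Logprob: logits[id] - logSum,
        })
    }
    return out
}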
// Options specified in [GenerateRequest]. If you add a new option here, also
// add it to the API docs.
type Options struct {
@ -450,6 +464,8 @@ type GenerateResponse struct {
Context []int `json:"context,omitempty"`
Metrics
Logits []float32 `json:"logits"`
}
// ModelDetails provides details about a model.
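For context, a minimal sketch of exercising the new fields from the Go client, assuming the existing api.ClientFromEnvironment/Generate flow; the model name is a placeholder:

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }
    req := &api.GenerateRequest{
        Model:        "llama3.2", // placeholder model name
        Prompt:       "Why is the sky blue?",
        ReturnLogits: true,
    }
    // Each streamed response should carry the logits for its token(s).
    err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
        fmt.Printf("%q: %d logits\n", resp.Response, len(resp.Logits))
        return nil
    })
    if err != nil {
        log.Fatal(err)
    }
}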

View File

@ -260,6 +260,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}
// GetLogits returns the logits from the last decode operation.
// The returned slice has length equal to the vocabulary size.
func (c *Context) GetLogits() []float32 {
logits := unsafe.Pointer(C.llama_get_logits(c.c))
if logits == nil {
return nil
}
// Get the number of vocabulary tokens to determine array size
vocabSize := c.Model().NumVocab()
return unsafe.Slice((*float32)(logits), vocabSize)
}
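Worth noting: the slice from unsafe.Slice aliases llama.cpp's internal logits buffer, which the next decode overwrites. A sketch of a copying wrapper for callers that keep logits across steps (helper name hypothetical; assumption based on llama_get_logits semantics):

// copyLogits snapshots the logits from the most recent decode so they
// survive later Decode calls, which reuse the underlying buffer.
func copyLogits(c *Context) []float32 {
    src := c.GetLogits()
    if src == nil {
        return nil
    }
    dst := make([]float32, len(src))
    copy(dst, src)
    return dst
}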
type ModelParams struct {
NumGpuLayers int
MainGpu int
@ -737,14 +750,3 @@ func SchemaToGrammar(schema []byte) []byte {
}
return buf[:n]
}

View File

@ -8,12 +8,14 @@ import (
"fmt"
"log"
"log/slog"
"math"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
@ -59,7 +61,7 @@ type Sequence struct {
crossAttention bool
// channel to send responses over
responses chan CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@ -88,6 +90,15 @@ type Sequence struct {
startGenerationTime time.Time
numDecoded int
numPromptInputs int
// whether to capture and return logits for each generated token
returnLogits bool
// logits from the most recent decode, pending delivery
logits []float32
// channel to send logits over
logitsOut chan []float32
}
type NewSequenceParams struct {
@ -96,6 +107,7 @@ type NewSequenceParams struct {
numKeep int
samplingParams *llama.SamplingParams
embedding bool
returnLogits bool
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@ -149,13 +161,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
returnLogits: params.returnLogits,
logitsOut: make(chan []float32, 100),
}, nil
}
@ -274,25 +288,36 @@ func (s *Server) allNil() bool {
}
func flushPending(seq *Sequence) bool {
if len(seq.pendingResponses) == 0 {
return true
}
content := strings.Join(seq.pendingResponses, "")
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(content) {
content = content[:len(content)-1]
}
seq.pendingResponses = nil
resp := CompletionResponse{
Content: content,
}
// Add logits if requested and available
if seq.returnLogits && seq.logits != nil {
slog.Info("returning logits - flushPending")
resp.Logits = seq.logits
seq.logits = nil
}
slog.Info("returning logits - flushPending", "logits", resp.Logits[0])
select {
case seq.responses <- joined:
case seq.responses <- resp:
return true
case <-seq.quit:
return false
@ -476,7 +501,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue
}
// capture logits from the last decode before sampling, if requested
if seq.returnLogits {
seq.logits = s.lc.GetLogits()
}
// sample a token
token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
seq.samplingCtx.Accept(token, true)
piece := s.model.TokenToPiece(token)
@ -572,10 +604,11 @@ type ImageData struct {
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
ReturnLogits bool `json:"return_logits"`
Options
}
@ -588,8 +621,10 @@ type Timings struct {
}
type CompletionResponse struct {
Content string `json:"content"`
Logits []float32 `json:"logits,omitempty"`
Tokens []string `json:"tokens,omitempty"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
@ -637,12 +672,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
slog.Info("completion request", "return_logits", req.ReturnLogits)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: req.NumKeep,
samplingParams: &samplingParams,
embedding: false,
returnLogits: req.ReturnLogits,
})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@ -691,10 +728,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
close(seq.quit)
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&content); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return
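To exercise this handler end to end, something like the following sketch should work; the URL is an assumption (the runner binds a port chosen by its parent process), and the JSON field names come from CompletionRequest/CompletionResponse above:

// requestLogits posts a completion request and prints each streamed chunk.
func requestLogits(url string) error {
    payload := strings.NewReader(`{"prompt": "Hello", "return_logits": true}`)
    resp, err := http.Post(url, "application/json", payload)
    if err != nil {
        return err
    }
    defer resp.Body.Close()
    // responses stream back as newline-delimited JSON objects
    dec := json.NewDecoder(resp.Body)
    for {
        var chunk CompletionResponse
        if err := dec.Decode(&chunk); err != nil {
            return err
        }
        fmt.Printf("content=%q logits=%d\n", chunk.Content, len(chunk.Logits))
        if chunk.Stop {
            return nil
        }
    }
}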

View File

@ -642,11 +642,12 @@ type ImageData struct {
}
type completion struct {
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Logits []float32 `json:"logits,omitempty"`
Timings struct {
PredictedN int `json:"predicted_n"`
@ -657,10 +658,11 @@ type completion struct {
}
type CompletionRequest struct {
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
ReturnLogits bool
}
type CompletionResponse struct {
@ -671,6 +673,7 @@ type CompletionResponse struct {
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
Logits []float32
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@ -696,6 +699,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
"seed": req.Options.Seed,
"stop": req.Options.Stop,
"image_data": req.Images,
"return_logits": req.ReturnLogits,
"cache_prompt": true,
}
@ -821,6 +825,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
Logits: c.Logits,
})
}
@ -837,6 +842,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
Logits: c.Logits,
})
return nil
}
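On the consuming side, per-token logits can be collected through the callback; a sketch assuming the package's LlamaServer interface, with server and context construction elided:

// collectLogits gathers the logit vector of every streamed token.
func collectLogits(ctx context.Context, srv LlamaServer, prompt string, opts *api.Options) ([][]float32, error) {
    var all [][]float32
    err := srv.Completion(ctx, CompletionRequest{
        Prompt:       prompt,
        Options:      opts,
        ReturnLogits: true,
    }, func(r CompletionResponse) {
        if r.Logits != nil {
            all = append(all, r.Logits)
        }
    })
    return all, err
}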

View File

@ -295,10 +295,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ReturnLogits: req.ReturnLogits,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
@ -312,6 +313,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration,
},
Logits: cr.Logits,
}
if _, err := sb.WriteString(cr.Content); err != nil {
@ -1541,16 +1543,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
slog.Debug("chat request", "images", len(images), "prompt", prompt)
slog.Info("chat request", "return_logits", req.ReturnLogits)
ch := make(chan any)
go func() {
defer close(ch)
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ReturnLogits: req.ReturnLogits,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
@ -1558,6 +1563,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
DoneReason: r.DoneReason,
Logits: r.Logits,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,