WIP but got logits n stuff
This commit is contained in:
parent
d7e7e6a01e
commit
c92d418a7c
16
api/types.go
16
api/types.go
@ -80,6 +80,8 @@ type GenerateRequest struct {
|
|||||||
// Options lists model-specific options. For example, temperature can be
|
// Options lists model-specific options. For example, temperature can be
|
||||||
// set through this field, if the model supports it.
|
// set through this field, if the model supports it.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
|
|
||||||
|
ReturnLogits bool `json:"return_logits,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatRequest describes a request sent by [Client.Chat].
|
// ChatRequest describes a request sent by [Client.Chat].
|
||||||
@ -105,6 +107,8 @@ type ChatRequest struct {
|
|||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
|
|
||||||
|
ReturnLogits bool `json:"return_logits,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tools []Tool
|
type Tools []Tool
|
||||||
@ -189,6 +193,7 @@ type ChatResponse struct {
|
|||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
Message Message `json:"message"`
|
Message Message `json:"message"`
|
||||||
DoneReason string `json:"done_reason,omitempty"`
|
DoneReason string `json:"done_reason,omitempty"`
|
||||||
|
Logits []float32 `json:"logits"`
|
||||||
|
|
||||||
Done bool `json:"done"`
|
Done bool `json:"done"`
|
||||||
|
|
||||||
@ -204,6 +209,15 @@ type Metrics struct {
|
|||||||
EvalDuration time.Duration `json:"eval_duration,omitempty"`
|
EvalDuration time.Duration `json:"eval_duration,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TokenLogprob struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
Logprob float32 `json:"logprob"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogProbs struct {
|
||||||
|
TopLogprobs []TokenLogprob `json:"top_logprobs"`
|
||||||
|
}
|
||||||
|
|
||||||
// Options specified in [GenerateRequest]. If you add a new option here, also
|
// Options specified in [GenerateRequest]. If you add a new option here, also
|
||||||
// add it to the API docs.
|
// add it to the API docs.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
@ -450,6 +464,8 @@ type GenerateResponse struct {
|
|||||||
Context []int `json:"context,omitempty"`
|
Context []int `json:"context,omitempty"`
|
||||||
|
|
||||||
Metrics
|
Metrics
|
||||||
|
|
||||||
|
Logits []float32 `json:"logits"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelDetails provides details about a model.
|
// ModelDetails provides details about a model.
|
||||||
|
@ -260,6 +260,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
|
|||||||
return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
|
return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLogits returns the logits from the last decode operation.
|
||||||
|
// The returned slice has length equal to the vocabulary size.
|
||||||
|
func (c *Context) GetLogits() []float32 {
|
||||||
|
logits := unsafe.Pointer(C.llama_get_logits(c.c))
|
||||||
|
if logits == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the number of vocabulary tokens to determine array size
|
||||||
|
vocabSize := c.Model().NumVocab()
|
||||||
|
return unsafe.Slice((*float32)(logits), vocabSize)
|
||||||
|
}
|
||||||
|
|
||||||
type ModelParams struct {
|
type ModelParams struct {
|
||||||
NumGpuLayers int
|
NumGpuLayers int
|
||||||
MainGpu int
|
MainGpu int
|
||||||
@ -737,14 +750,3 @@ func SchemaToGrammar(schema []byte) []byte {
|
|||||||
}
|
}
|
||||||
return buf[:n]
|
return buf[:n]
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLogits returns the logits from the last decode operation.
|
|
||||||
// The returned slice has length equal to the vocabulary size.
|
|
||||||
func (c *Context) GetLogits() []float32 {
|
|
||||||
logits := unsafe.Pointer(C.llama_get_logits(c.c))
|
|
||||||
if logits == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the number of vocabulary tokens to determine array size
|
|
||||||
vocabSize := c.Model().NumVocab()
|
|
||||||
|
@ -8,12 +8,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -59,7 +61,7 @@ type Sequence struct {
|
|||||||
crossAttention bool
|
crossAttention bool
|
||||||
|
|
||||||
// channel to send responses over
|
// channel to send responses over
|
||||||
responses chan string
|
responses chan CompletionResponse
|
||||||
|
|
||||||
// channel to stop decoding (such as if the remote connection is closed)
|
// channel to stop decoding (such as if the remote connection is closed)
|
||||||
quit chan bool
|
quit chan bool
|
||||||
@ -88,6 +90,15 @@ type Sequence struct {
|
|||||||
startGenerationTime time.Time
|
startGenerationTime time.Time
|
||||||
numDecoded int
|
numDecoded int
|
||||||
numPromptInputs int
|
numPromptInputs int
|
||||||
|
|
||||||
|
// New flag we need to add to Sequence struct
|
||||||
|
returnLogits bool
|
||||||
|
|
||||||
|
// Using our new GetLogits() method
|
||||||
|
logits []float32
|
||||||
|
|
||||||
|
// Add new channel for logits
|
||||||
|
logitsOut chan []float32
|
||||||
}
|
}
|
||||||
|
|
||||||
type NewSequenceParams struct {
|
type NewSequenceParams struct {
|
||||||
@ -96,6 +107,7 @@ type NewSequenceParams struct {
|
|||||||
numKeep int
|
numKeep int
|
||||||
samplingParams *llama.SamplingParams
|
samplingParams *llama.SamplingParams
|
||||||
embedding bool
|
embedding bool
|
||||||
|
returnLogits bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||||
@ -149,13 +161,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|||||||
startProcessingTime: startTime,
|
startProcessingTime: startTime,
|
||||||
numPredict: params.numPredict,
|
numPredict: params.numPredict,
|
||||||
pendingResponses: make([]string, 0),
|
pendingResponses: make([]string, 0),
|
||||||
responses: make(chan string, 100),
|
responses: make(chan CompletionResponse, 100),
|
||||||
quit: make(chan bool, 1),
|
quit: make(chan bool, 1),
|
||||||
embedding: make(chan []float32, 1),
|
embedding: make(chan []float32, 1),
|
||||||
samplingCtx: sc,
|
samplingCtx: sc,
|
||||||
embeddingOnly: params.embedding,
|
embeddingOnly: params.embedding,
|
||||||
stop: params.stop,
|
stop: params.stop,
|
||||||
numKeep: params.numKeep,
|
numKeep: params.numKeep,
|
||||||
|
returnLogits: params.returnLogits,
|
||||||
|
logitsOut: make(chan []float32, 100),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,25 +288,36 @@ func (s *Server) allNil() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func flushPending(seq *Sequence) bool {
|
func flushPending(seq *Sequence) bool {
|
||||||
joined := strings.Join(seq.pendingResponses, "")
|
if len(seq.pendingResponses) == 0 {
|
||||||
seq.pendingResponses = []string{}
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
content := strings.Join(seq.pendingResponses, "")
|
||||||
// Check if there are any partial UTF-8 characters remaining.
|
// Check if there are any partial UTF-8 characters remaining.
|
||||||
// We already check and queue as we are generating but some may
|
// We already check and queue as we are generating but some may
|
||||||
// still make it here:
|
// still make it here:
|
||||||
// - Sequence is ending, e.g. generation limit has been hit
|
// - Sequence is ending, e.g. generation limit has been hit
|
||||||
// - Invalid characters in the middle of a string
|
// - Invalid characters in the middle of a string
|
||||||
// This is a stricter check to ensure we never output invalid Unicode.
|
// This is a stricter check to ensure we never output invalid Unicode.
|
||||||
for !utf8.ValidString(joined) {
|
for !utf8.ValidString(content) {
|
||||||
joined = joined[:len(joined)-1]
|
content = content[:len(content)-1]
|
||||||
|
}
|
||||||
|
seq.pendingResponses = nil
|
||||||
|
|
||||||
|
resp := CompletionResponse{
|
||||||
|
Content: content,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(joined) == 0 {
|
// Add logits if requested and available
|
||||||
return true
|
if seq.returnLogits && seq.logits != nil {
|
||||||
|
slog.Info("returning logits - flushPending")
|
||||||
|
resp.Logits = seq.logits
|
||||||
|
seq.logits = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slog.Info("returning logits - flushPending", "logits", resp.Logits[0])
|
||||||
select {
|
select {
|
||||||
case seq.responses <- joined:
|
case seq.responses <- resp:
|
||||||
return true
|
return true
|
||||||
case <-seq.quit:
|
case <-seq.quit:
|
||||||
return false
|
return false
|
||||||
@ -476,7 +501,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// sample a token
|
// Before sampling:
|
||||||
|
if seq.returnLogits { // New flag we need to add to Sequence struct
|
||||||
|
slog.Info("returning logits")
|
||||||
|
seq.logits = s.lc.GetLogits() // Using our new GetLogits() method
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then sample token
|
||||||
token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
|
token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
|
||||||
seq.samplingCtx.Accept(token, true)
|
seq.samplingCtx.Accept(token, true)
|
||||||
piece := s.model.TokenToPiece(token)
|
piece := s.model.TokenToPiece(token)
|
||||||
@ -572,10 +604,11 @@ type ImageData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Images []ImageData `json:"image_data"`
|
Images []ImageData `json:"image_data"`
|
||||||
Grammar string `json:"grammar"`
|
Grammar string `json:"grammar"`
|
||||||
CachePrompt bool `json:"cache_prompt"`
|
CachePrompt bool `json:"cache_prompt"`
|
||||||
|
ReturnLogits bool `json:"return_logits"`
|
||||||
|
|
||||||
Options
|
Options
|
||||||
}
|
}
|
||||||
@ -588,8 +621,10 @@ type Timings struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Stop bool `json:"stop"`
|
Logits []float32 `json:"logits,omitempty"`
|
||||||
|
Tokens []string `json:"tokens,omitempty"`
|
||||||
|
Stop bool `json:"stop"`
|
||||||
|
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
@ -637,12 +672,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
samplingParams.Seed = uint32(req.Seed)
|
samplingParams.Seed = uint32(req.Seed)
|
||||||
samplingParams.Grammar = req.Grammar
|
samplingParams.Grammar = req.Grammar
|
||||||
|
|
||||||
|
slog.Info("completion request", "return_logits", req.ReturnLogits)
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.NumPredict,
|
numPredict: req.NumPredict,
|
||||||
stop: req.Stop,
|
stop: req.Stop,
|
||||||
numKeep: req.NumKeep,
|
numKeep: req.NumKeep,
|
||||||
samplingParams: &samplingParams,
|
samplingParams: &samplingParams,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
|
returnLogits: req.ReturnLogits,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||||
@ -691,10 +728,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
close(seq.quit)
|
close(seq.quit)
|
||||||
return
|
return
|
||||||
case content, ok := <-seq.responses:
|
case content, ok := <-seq.responses:
|
||||||
|
slog.Info("logits in last chan", "content", content.Logits[0])
|
||||||
if ok {
|
if ok {
|
||||||
if err := json.NewEncoder(w).Encode(&CompletionResponse{
|
slog.Info("content", "content", content.Content)
|
||||||
Content: content,
|
if err := json.NewEncoder(w).Encode(&content); err != nil {
|
||||||
}); err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
close(seq.quit)
|
close(seq.quit)
|
||||||
return
|
return
|
||||||
|
@ -642,11 +642,12 @@ type ImageData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type completion struct {
|
type completion struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Stop bool `json:"stop"`
|
Stop bool `json:"stop"`
|
||||||
StoppedLimit bool `json:"stopped_limit"`
|
StoppedLimit bool `json:"stopped_limit"`
|
||||||
|
Logits []float32 `json:"logits,omitempty"`
|
||||||
|
|
||||||
Timings struct {
|
Timings struct {
|
||||||
PredictedN int `json:"predicted_n"`
|
PredictedN int `json:"predicted_n"`
|
||||||
@ -657,10 +658,11 @@ type completion struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Prompt string
|
Prompt string
|
||||||
Format json.RawMessage
|
Format json.RawMessage
|
||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options *api.Options
|
Options *api.Options
|
||||||
|
ReturnLogits bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
@ -671,6 +673,7 @@ type CompletionResponse struct {
|
|||||||
PromptEvalDuration time.Duration
|
PromptEvalDuration time.Duration
|
||||||
EvalCount int
|
EvalCount int
|
||||||
EvalDuration time.Duration
|
EvalDuration time.Duration
|
||||||
|
Logits []float32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||||
@ -696,6 +699,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
"seed": req.Options.Seed,
|
"seed": req.Options.Seed,
|
||||||
"stop": req.Options.Stop,
|
"stop": req.Options.Stop,
|
||||||
"image_data": req.Images,
|
"image_data": req.Images,
|
||||||
|
"return_logits": req.ReturnLogits,
|
||||||
"cache_prompt": true,
|
"cache_prompt": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -821,6 +825,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
if c.Content != "" {
|
if c.Content != "" {
|
||||||
fn(CompletionResponse{
|
fn(CompletionResponse{
|
||||||
Content: c.Content,
|
Content: c.Content,
|
||||||
|
Logits: c.Logits,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -837,6 +842,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||||
EvalCount: c.Timings.PredictedN,
|
EvalCount: c.Timings.PredictedN,
|
||||||
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
||||||
|
Logits: c.Logits,
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -295,10 +295,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
|
ReturnLogits: req.ReturnLogits,
|
||||||
}, func(cr llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
res := api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
@ -312,6 +313,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
EvalCount: cr.EvalCount,
|
EvalCount: cr.EvalCount,
|
||||||
EvalDuration: cr.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
},
|
},
|
||||||
|
Logits: cr.Logits,
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := sb.WriteString(cr.Content); err != nil {
|
if _, err := sb.WriteString(cr.Content); err != nil {
|
||||||
@ -1541,16 +1543,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
|
|
||||||
slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
||||||
|
|
||||||
|
slog.Info("chat request", "return_logits", req.ReturnLogits)
|
||||||
|
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
var toolCallIndex int = 0
|
var toolCallIndex int = 0
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
|
ReturnLogits: true,
|
||||||
}, func(r llm.CompletionResponse) {
|
}, func(r llm.CompletionResponse) {
|
||||||
res := api.ChatResponse{
|
res := api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
@ -1558,6 +1563,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||||
Done: r.Done,
|
Done: r.Done,
|
||||||
DoneReason: r.DoneReason,
|
DoneReason: r.DoneReason,
|
||||||
|
Logits: r.Logits,
|
||||||
Metrics: api.Metrics{
|
Metrics: api.Metrics{
|
||||||
PromptEvalCount: r.PromptEvalCount,
|
PromptEvalCount: r.PromptEvalCount,
|
||||||
PromptEvalDuration: r.PromptEvalDuration,
|
PromptEvalDuration: r.PromptEvalDuration,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user