Compare commits

...

5 Commits

Author           SHA1        Message                           Date
Bruce MacDonald  b88489a87e  ...                               2025-02-20 09:36:36 -08:00
Bruce MacDonald  fdbb0b5cfe  prototype                         2025-02-13 15:22:15 -08:00
Bruce MacDonald  64f95067ba  ...                               2025-02-13 14:02:04 -08:00
Bruce MacDonald  6dfcdec2da  send completion response on chan  2025-02-12 17:03:52 -08:00
Bruce MacDonald  7d16ec8fe8  print logprobs                    2025-02-12 16:36:03 -08:00
8 changed files with 300 additions and 119 deletions

View File

@@ -77,6 +77,8 @@ type GenerateRequest struct {
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
// LogProbs, when greater than zero, requests that many top token
// log probabilities be returned for each generated token.
LogProbs int `json:"logprobs,omitempty"`
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
@@ -103,6 +105,8 @@ type ChatRequest struct {
// Tools is an optional list of tools the model has access to.
Tools `json:"tools,omitempty"`
// LogProbs, when greater than zero, requests that many top token
// log probabilities be returned for each generated token.
LogProbs int `json:"logprobs,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
@@ -182,13 +186,20 @@ func (t *ToolFunction) String() string {
return string(bts)
}
// TokenProbs represents log-probability information for a single token.
type TokenProbs struct {
TokenID int `json:"id"`
LogProb float32 `json:"logprob"`
Token string `json:"token"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
LogProbs []TokenProbs `json:"logprobs,omitempty"`
Done bool `json:"done"`
@@ -452,6 +463,8 @@ type GenerateResponse struct {
// can be sent in the next request to keep a conversational memory.
Context []int `json:"context,omitempty"`
LogProbs []TokenProbs `json:"logprobs,omitempty"`
Metrics
}
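Taken together, the API-level change is: requests opt in with a logprobs count, and responses carry per-token TokenProbs entries. A minimal client-side sketch, assuming these prototype fields ship in the api package as shown above (the model name and prompt are illustrative):

	req := &api.GenerateRequest{
		Model:    "llama3.2",
		Prompt:   "The capital of France is",
		LogProbs: 3, // top 3 candidates per generated token
	}
	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
		for _, tp := range resp.LogProbs {
			fmt.Printf("%q id=%d logprob=%.3f\n", tp.Token, tp.TokenID, tp.LogProb)
		}
		return nil
	})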

View File

@@ -50,7 +50,7 @@ import (
_ "github.com/ollama/ollama/llama/llama.cpp/common"
_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
_ "github.com/ollama/ollama/llama/llama.cpp/src"
"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)
func BackendInit() {
@@ -220,6 +220,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
return embeddings
}
// GetLogits returns the logits from the last decode operation.
// The returned slice has length equal to the vocabulary size.
func (c *Context) GetLogits() []float32 {
logits := unsafe.Pointer(C.llama_get_logits(c.c))
if logits == nil {
return nil
}
// Get the number of vocabulary tokens to determine array size
vocabSize := c.Model().NumVocab()
return unsafe.Slice((*float32)(logits), vocabSize)
}
type ModelParams struct {
NumGpuLayers int
MainGpu int
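The slice returned by GetLogits aliases the llama.cpp context's internal buffer, so the values are only valid until the next decode; callers that need them later should copy them first. This follows the usual llama_get_logits semantics (asserted here, not stated in this diff), and it is exactly the copy Server.probs makes below:

	logits := ctx.GetLogits()
	if logits != nil {
		saved := make([]float32, len(logits))
		copy(saved, logits) // safe to keep across subsequent Decode calls
	}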

View File

@@ -8,12 +8,14 @@ import (
"fmt"
"log"
"log/slog"
"math"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
@@ -48,8 +50,9 @@ type Sequence struct {
// inputs that have been added to a batch but not yet submitted to Decode
pendingInputs []input
// responses that have been generated but not yet returned
// (e.g. held back while checking for stop sequences)
pendingResponses []string
pendingResponses []CompletionResponse
// input cache being used by this sequence
cache *InputCacheSlot
@@ -59,7 +62,7 @@ type Sequence struct {
crossAttention bool
// channel to send responses over
responses chan string
responses chan CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@@ -83,6 +86,11 @@ type Sequence struct {
doneReason string
logits []float32
// number of logprobs to return with the completion response
logprobs int
// Metrics
startProcessingTime time.Time
startGenerationTime time.Time
@@ -96,6 +104,7 @@ type NewSequenceParams struct {
numKeep int
samplingParams *llama.SamplingParams
embedding bool
logprobs int
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -148,14 +157,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
pendingResponses: make([]CompletionResponse, 0),
responses: make(chan CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
logprobs: params.logprobs,
}, nil
}
@@ -274,29 +284,37 @@ func (s *Server) allNil() bool {
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
if len(joined) == 0 {
if len(seq.pendingResponses) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
resps := seq.pendingResponses
seq.pendingResponses = []CompletionResponse{}

for _, resp := range resps {
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(resp.Content) {
resp.Content = resp.Content[:len(resp.Content)-1]
}

select {
case seq.responses <- resp:
case <-seq.quit:
// The remote connection closed; stop flushing and report failure.
return false
}
}
return true
}
func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -350,6 +368,63 @@ func (s *Server) run(ctx context.Context) {
}
}
// TokenProbs represents probability information for a token
type TokenProbs struct {
TokenID int `json:"id"`
Logit float32 `json:"logit"`
Prob float32 `json:"prob"`
LogProb float32 `json:"logprob"`
Token string `json:"token"`
}
// probs converts raw logits into per-token probabilities and
// log probabilities, sorted in descending order of likelihood
func probs(logits []float32, vocabSize int) []TokenProbs {
probs := make([]TokenProbs, vocabSize)
// Initialize token data with logits
for i := 0; i < vocabSize; i++ {
probs[i] = TokenProbs{
TokenID: i,
Logit: logits[i],
}
}
// Sort tokens by logits in descending order
sort.Slice(probs, func(i, j int) bool {
return probs[i].Logit > probs[j].Logit
})
// Apply softmax, subtracting the max logit for numerical stability
maxLogit := probs[0].Logit
var sum float32 = 0.0
for i := range probs {
p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
probs[i].Prob = p
sum += p
}
// Normalize probabilities and calculate log probs
for i := range probs {
prob := probs[i].Prob / sum
probs[i].Prob = prob
probs[i].LogProb = float32(math.Log(float64(prob)))
}
return probs
}
// probs returns sorted token probabilities computed from the logits
// of the last decode operation
func (s *Server) probs(seq *Sequence) []TokenProbs {
// Snapshot the logits before the next decode overwrites them
logits := s.lc.GetLogits()
seq.logits = make([]float32, len(logits))
copy(seq.logits, logits)
vocabSize := s.model.NumVocab()
return probs(logits, vocabSize)
}
// TODO (jmorganca): processBatch should be simplified, removing:
// * sampling
// * stop token checking
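For intuition, here is what probs produces for the toy 4-token vocabulary that TestProbs uses further down (values rounded by hand):

	out := probs([]float32{1.0, 2.0, 0.5, -1.0}, 4)
	// Sorted by logit descending, then softmax-normalized:
	// out[0] = {TokenID: 1, Logit: 2.0,  Prob: ~0.61, LogProb: ~-0.50}
	// out[1] = {TokenID: 0, Logit: 1.0,  Prob: ~0.22, LogProb: ~-1.49}
	// out[2] = {TokenID: 2, Logit: 0.5,  Prob: ~0.14, LogProb: ~-2.00}
	// out[3] = {TokenID: 3, Logit: -1.0, Prob: ~0.03, LogProb: ~-3.50}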
@@ -483,6 +558,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.numPredicted++
resp := CompletionResponse{Content: piece}
if seq.logprobs > 0 {
// TODO: return selected token in logprobs always
resp.LogProbs = s.probs(seq)
// TODO: fix this logprobs limit
resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)]
for i := range resp.LogProbs {
// decode the token id to a piece
resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID)
}
}
// if it's an end of sequence token, break
if s.model.TokenIsEog(token) {
// TODO (jmorganca): we should send this back
@@ -495,16 +583,21 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
// TODO: add probs here
seq.pendingResponses = append(seq.pendingResponses, resp)
var sequence string
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := findStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
// TODO: fix this stop sequence caching
var tokenTruncated bool
origLen := len(seq.pendingResponses)
seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
newLen := len(seq.pendingResponses)
origLen := len(sequence)
sequence, tokenTruncated = truncateStop(sequence, stop)
newLen := len(sequence)
// Update the cache based on the tokens that will be returned:
// - We have 1 token more than is currently in the cache because
@@ -575,6 +668,7 @@ type CompletionRequest struct {
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Logprobs int `json:"logprobs,omitempty"`
Options
}
@@ -590,8 +684,10 @@ type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
LogProbs []TokenProbs `json:"logprobs,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
@@ -609,10 +705,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
@@ -641,6 +733,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
numKeep: req.NumKeep,
samplingParams: &samplingParams,
embedding: false,
logprobs: req.Logprobs,
})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -688,11 +781,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
case <-r.Context().Done():
close(seq.quit)
return
case content, ok := <-seq.responses:
case resp, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Content: content,
}); err != nil {
fmt.Println("response", resp)
if err := json.NewEncoder(w).Encode(&resp); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return
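Each CompletionResponse is encoded onto the wire as its own JSON object, so a client can decode the stream incrementally. A hedged sketch of the consuming side (the URL, and reusing the runner's CompletionResponse type, are assumptions for illustration):

	func readCompletionStream(url string, body io.Reader) error {
		resp, err := http.Post(url, "application/json", body)
		if err != nil {
			return err
		}
		defer resp.Body.Close()

		dec := json.NewDecoder(resp.Body)
		for {
			var cr CompletionResponse
			if err := dec.Decode(&cr); err == io.EOF {
				return nil
			} else if err != nil {
				return err
			}
			// cr.Content is the next text piece; cr.LogProbs, when
			// requested, holds the top-N candidates for the sampled token.
			if cr.Stop {
				return nil
			}
		}
	}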

View File

@@ -0,0 +1,58 @@
package runner
import (
"math"
"testing"
)
func TestProbs(t *testing.T) {
// Input test data
logits := []float32{1.0, 2.0, 0.5, -1.0}
vocabSize := 4
want := []TokenProbs{
{TokenID: 1, Logit: 2.0}, // Highest logit
{TokenID: 0, Logit: 1.0}, // Second highest
{TokenID: 2, Logit: 0.5}, // Third
{TokenID: 3, Logit: -1.0}, // Lowest
}
got := probs(logits, vocabSize)
// Test 1: Check sorting order
for i := 0; i < len(got)-1; i++ {
if got[i].Logit < got[i+1].Logit {
t.Errorf("probs not properly sorted: logit at pos %d (%f) < logit at pos %d (%f)",
i, got[i].Logit, i+1, got[i+1].Logit)
}
}
// Test 2: Check probability normalization
var sum float32
for _, p := range got {
sum += p.Prob
}
if math.Abs(float64(sum-1.0)) > 1e-6 {
t.Errorf("probabilities do not sum to 1: got %v", sum)
}
// Test 3: Check token IDs match expected order
for i, want := range want {
if got[i].TokenID != want.TokenID {
t.Errorf("wrong token ID at position %d: got %d, want %d",
i, got[i].TokenID, want.TokenID)
}
if got[i].Logit != want.Logit {
t.Errorf("wrong logit at position %d: got %f, want %f",
i, got[i].Logit, want.Logit)
}
}
// Test 4: Check log probs are correctly calculated
for i, p := range got {
expectedLogProb := float32(math.Log(float64(p.Prob)))
if math.Abs(float64(p.LogProb-expectedLogProb)) > 1e-6 {
t.Errorf("wrong log prob at position %d: got %f, want %f",
i, p.LogProb, expectedLogProb)
}
}
}

View File

@@ -26,43 +26,15 @@ func containsStopSuffix(sequence string, stops []string) bool {
return false
}
// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case)
func truncateStop(pieces []string, stop string) ([]string, bool) {
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
// truncateStop removes the provided stop string from sequence,
// returning both the truncated sequence and a bool indicating if truncation occurred
func truncateStop(sequence string, stop string) (string, bool) {
index := strings.Index(sequence, stop)
if index == -1 {
return pieces, false
return sequence, false
}
joined = joined[:index]
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
}
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
break
}
end := start + length
if end > len(joined) {
end = len(joined)
tokenTruncated = true
}
result = append(result, joined[start:end])
start = end
}
return result, tokenTruncated
return sequence[:index], true
}
func incompleteUnicode(token string) bool {
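The string-based rewrite makes the contract simple: everything from the first occurrence of stop onward is dropped, and the bool reports whether a stop was found. For example (mirroring the updated tests below):

	out, found := truncateStop("hello wor", "llo w") // out == "he", found == true
	out, found = truncateStop("hello world", "xyz") // out == "hello world", found == false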

View File

@@ -1,60 +1,60 @@
package runner
import (
"reflect"
"testing"
)
func TestTruncateStop(t *testing.T) {
tests := []struct {
name string
pieces []string
sequence string
stop string
expected []string
expected string
expectedTrunc bool
}{
{
name: "Single word",
pieces: []string{"hello", "world"},
sequence: "helloworld",
stop: "world",
expected: []string{"hello"},
expectedTrunc: false,
expected: "hello",
expectedTrunc: true,
},
{
name: "Partial",
pieces: []string{"hello", "wor"},
sequence: "hellowor",
stop: "or",
expected: []string{"hello", "w"},
expected: "hellow",
expectedTrunc: true,
},
{
name: "Suffix",
pieces: []string{"Hello", " there", "!"},
sequence: "Hello there!",
stop: "!",
expected: []string{"Hello", " there"},
expected: "Hello there",
expectedTrunc: false,
expectedTrunc: true,
},
{
name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"},
sequence: "Hello there!",
stop: "there!",
expected: []string{"Hello", " "},
expected: "Hello ",
expectedTrunc: true,
},
{
name: "Middle",
pieces: []string{"hello", " wor"},
sequence: "hello wor",
stop: "llo w",
expected: []string{"he"},
expected: "he",
expectedTrunc: true,
},
{
name: "No stop found",
sequence: "hello world",
stop: "xyz",
expected: "hello world",
expectedTrunc: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := truncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
result, truncated := truncateStop(tt.sequence, tt.stop)
if result != tt.expected || truncated != tt.expectedTrunc {
t.Errorf("truncateStop(%q, %q): have %q (%v); want %q (%v)",
tt.sequence, tt.stop, result, truncated, tt.expected, tt.expectedTrunc)
}
})
}

View File

@@ -644,12 +644,22 @@ type ImageData struct {
AspectRatioID int `json:"aspect_ratio_id"`
}
// TokenProbs represents probability information for a token
type TokenProbs struct {
TokenID int `json:"id"`
Logit float32 `json:"logit"`
Prob float32 `json:"prob"`
LogProb float32 `json:"logprob"`
Token string `json:"token"`
}
type completion struct {
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
LogProbs []TokenProbs `json:"logprobs"`
Timings struct {
PredictedN int `json:"predicted_n"`
@@ -660,14 +670,16 @@ type completion struct {
}
type CompletionRequest struct {
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
Prompt string
Format json.RawMessage
Images []ImageData
LogProbs int
Options *api.Options
}
type CompletionResponse struct {
Content string
LogProbs []TokenProbs
DoneReason string
Done bool
PromptEvalCount int
@@ -698,9 +710,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
"seed": req.Options.Seed,
"stop": req.Options.Stop,
"image_data": req.Images,
"logprobs": req.LogProbs,
"cache_prompt": true,
}
fmt.Println("completion request:", request)
if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
@@ -796,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
continue
}
// slog.Debug("got line", "line", string(line))
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
evt = line
@@ -822,7 +836,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
Content: c.Content,
LogProbs: c.LogProbs,
})
}
@@ -839,6 +854,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
LogProbs: c.LogProbs,
})
return nil
}
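A sketch of driving this layer with logprobs enabled; srv stands in for whatever LlamaServer handle the caller holds, and the option plumbing is elided:

	err := srv.Completion(ctx, llm.CompletionRequest{
		Prompt:   prompt,
		LogProbs: 3,
		Options:  &opts,
	}, func(r llm.CompletionResponse) {
		for _, p := range r.LogProbs {
			slog.Debug("alternative", "token", p.Token, "logprob", p.LogProb)
		}
	})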

View File

@@ -293,11 +293,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
LogProbs: req.LogProbs,
Options: opts,
}, func(cr llm.CompletionResponse) {
fmt.Printf("banana: %#v\n", cr)
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
@@ -311,6 +313,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalDuration: cr.EvalDuration,
},
}
for _, p := range cr.LogProbs {
res.LogProbs = append(res.LogProbs, api.TokenProbs{
TokenID: p.TokenID,
LogProb: p.LogProb,
Token: p.Token,
})
}
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
@@ -1466,10 +1475,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
LogProbs: req.LogProbs,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
@@ -1484,6 +1494,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
for _, p := range r.LogProbs {
res.LogProbs = append(res.LogProbs, api.TokenProbs{
TokenID: p.TokenID,
LogProb: p.LogProb,
Token: p.Token,
})
}
if r.Done {
res.TotalDuration = time.Since(checkpointStart)