This commit is contained in:
Bruce MacDonald 2025-02-13 14:02:04 -08:00
parent 6dfcdec2da
commit 64f95067ba
3 changed files with 105 additions and 62 deletions

View File

@ -50,8 +50,9 @@ type Sequence struct {
// inputs that have been added to a batch but not yet submitted to Decode // inputs that have been added to a batch but not yet submitted to Decode
pendingInputs []input pendingInputs []input
// TODO: update this comment
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []CompletionResponse
// input cache being used by this sequence // input cache being used by this sequence
cache *InputCacheSlot cache *InputCacheSlot
@ -87,6 +88,9 @@ type Sequence struct {
logits []float32 logits []float32
// number of logprobs to return with the completion response
logprobs int
// Metrics // Metrics
startProcessingTime time.Time startProcessingTime time.Time
startGenerationTime time.Time startGenerationTime time.Time
@ -152,7 +156,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
numPromptInputs: len(inputs), numPromptInputs: len(inputs),
startProcessingTime: startTime, startProcessingTime: startTime,
numPredict: params.numPredict, numPredict: params.numPredict,
pendingResponses: make([]string, 0), pendingResponses: make([]CompletionResponse, 0),
responses: make(chan CompletionResponse, 100), responses: make(chan CompletionResponse, 100),
quit: make(chan bool, 1), quit: make(chan bool, 1),
embedding: make(chan []float32, 1), embedding: make(chan []float32, 1),
@ -281,8 +285,11 @@ func flushPending(seq *Sequence) bool {
if len(seq.pendingResponses) == 0 { if len(seq.pendingResponses) == 0 {
return true return true
} }
content := strings.Join(seq.pendingResponses, "") content := ""
seq.pendingResponses = []string{} for _, resp := range seq.pendingResponses {
content += resp.Content
}
seq.pendingResponses = []CompletionResponse{}
// Check if there are any partial UTF-8 characters remaining. // Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may // We already check and queue as we are generating but some may
@ -362,27 +369,27 @@ func (s *Server) run(ctx context.Context) {
} }
} }
// TokenData represents probability information for a token // TokenProbs represents probability information for a token
type TokenData struct { type TokenProbs struct {
TokenID int TokenID int
Logit float32 Logit float32
Prob float32 Prob float32
LogProb float32 LogProb float32
} }
// getTokenProbabilities returns sorted token probabilities for a specific token index // probs returns sorted token probabilities for a specific token index
func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData { func (s *Server) probs(seq *Sequence) []TokenProbs {
// Get logits for the specific token index // Get logits for the specific token index
logits := s.lc.GetLogits() logits := s.lc.GetLogits()
seq.logits = make([]float32, len(logits)) seq.logits = make([]float32, len(logits))
copy(seq.logits, logits) copy(seq.logits, logits)
vocabSize := s.model.NumVocab() vocabSize := s.model.NumVocab()
probs := make([]TokenData, vocabSize) probs := make([]TokenProbs, vocabSize)
// Initialize token data with logits // Initialize token data with logits
for i := 0; i < vocabSize; i++ { for i := 0; i < vocabSize; i++ {
probs[i] = TokenData{ probs[i] = TokenProbs{
TokenID: i, TokenID: i,
Logit: logits[i], Logit: logits[i],
} }
@ -546,10 +553,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.numPredicted++ seq.numPredicted++
// TODO: only do this when flag specified if seq.logprobs > 0 {
probs := s.getTokenProbabilities(seq) // TODO: return selected token in logprobs always
for i := range 10 { // probs := s.probs(seq)
slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
} }
// if it's an end of sequence token, break // if it's an end of sequence token, break
@ -564,8 +570,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.inputs = []input{{token: token}} seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece) // TODO: add probs here
sequence := strings.Join(seq.pendingResponses, "") seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece})
var sequence string
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := findStop(sequence, seq.stop); ok { if ok, stop := findStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

View File

@ -29,40 +29,43 @@ func containsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces, // truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating // returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case) // the last piece if required (and signalling if this was the case)
func truncateStop(pieces []string, stop string) ([]string, bool) { func truncateStop(pieces []CompletionResponse, stop string) ([]CompletionResponse, bool) {
joined := strings.Join(pieces, "") // Build complete string and find stop position
var completeStr string
for _, piece := range pieces {
completeStr += piece.Content
}
index := strings.Index(joined, stop) stopStart := strings.Index(completeStr, stop)
if index == -1 { if stopStart == -1 {
return pieces, false return pieces, false
} }
joined = joined[:index] // Build result up to stop position
result := make([]CompletionResponse, 0)
accumulated := 0
// Split truncated string back into pieces of original lengths truncated := false
lengths := make([]int, len(pieces)) for _, piece := range pieces {
for i, piece := range pieces { if accumulated+len(piece.Content) <= stopStart {
lengths[i] = len(piece) result = append(result, piece)
} accumulated += len(piece.Content)
continue
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
break
} }
end := start + length if accumulated < stopStart {
if end > len(joined) { truncPiece := piece
end = len(joined) truncPiece.Content = piece.Content[:stopStart-accumulated]
tokenTruncated = true if len(truncPiece.Content) > 0 {
result = append(result, truncPiece)
truncated = true
}
} }
result = append(result, joined[start:end]) break
start = end
} }
return result, tokenTruncated // Signal if we had to truncate the last piece
return result, truncated
} }
func incompleteUnicode(token string) bool { func incompleteUnicode(token string) bool {

View File

@ -8,44 +8,74 @@ import (
func TestTruncateStop(t *testing.T) { func TestTruncateStop(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
pieces []string pieces []CompletionResponse
stop string stop string
expected []string expected []CompletionResponse
expectedTrunc bool expectedTrunc bool
}{ }{
{ {
name: "Single word", name: "Single word",
pieces: []string{"hello", "world"}, pieces: []CompletionResponse{
stop: "world", {Content: "hello"},
expected: []string{"hello"}, {Content: "world"},
},
stop: "world",
expected: []CompletionResponse{
{Content: "hello"},
},
expectedTrunc: false, expectedTrunc: false,
}, },
{ {
name: "Partial", name: "Partial",
pieces: []string{"hello", "wor"}, pieces: []CompletionResponse{
stop: "or", {Content: "hello"},
expected: []string{"hello", "w"}, {Content: "wor"},
},
stop: "or",
expected: []CompletionResponse{
{Content: "hello"},
{Content: "w"},
},
expectedTrunc: true, expectedTrunc: true,
}, },
{ {
name: "Suffix", name: "Suffix",
pieces: []string{"Hello", " there", "!"}, pieces: []CompletionResponse{
stop: "!", {Content: "Hello"},
expected: []string{"Hello", " there"}, {Content: " there"},
{Content: "!"},
},
stop: "!",
expected: []CompletionResponse{
{Content: "Hello"},
{Content: " there"},
},
expectedTrunc: false, expectedTrunc: false,
}, },
{ {
name: "Suffix partial", name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"}, pieces: []CompletionResponse{
stop: "there!", {Content: "Hello"},
expected: []string{"Hello", " "}, {Content: " the"},
{Content: "re!"},
},
stop: "there!",
expected: []CompletionResponse{
{Content: "Hello"},
{Content: " "},
},
expectedTrunc: true, expectedTrunc: true,
}, },
{ {
name: "Middle", name: "Middle",
pieces: []string{"hello", " wor"}, pieces: []CompletionResponse{
stop: "llo w", {Content: "hello"},
expected: []string{"he"}, {Content: " wor"},
},
stop: "llo w",
expected: []CompletionResponse{
{Content: "he"},
},
expectedTrunc: true, expectedTrunc: true,
}, },
} }