...
This commit is contained in: parent 6dfcdec2da, commit 64f95067ba
@@ -50,8 +50,9 @@ type Sequence struct {
 	// inputs that have been added to a batch but not yet submitted to Decode
 	pendingInputs []input
 
+	// TODO: update this comment
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
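CompletionResponse itself is not defined anywhere in this diff. A minimal assumed shape, only so the pendingResponses changes are easy to follow (the real type in the runner package almost certainly carries more fields, such as stop reason and timing info):

// Assumed sketch of the type this diff relies on; not taken from the commit.
type CompletionResponse struct {
	Content string
}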
@@ -87,6 +88,9 @@ type Sequence struct {
 
 	logits []float32
 
+	// number of logprobs to return with the completion response
+	logprobs int
+
 	// Metrics
 	startProcessingTime time.Time
 	startGenerationTime time.Time
@@ -152,7 +156,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		numPromptInputs: len(inputs),
 		startProcessingTime: startTime,
 		numPredict: params.numPredict,
-		pendingResponses: make([]string, 0),
+		pendingResponses: make([]CompletionResponse, 0),
 		responses: make(chan CompletionResponse, 100),
 		quit: make(chan bool, 1),
 		embedding: make(chan []float32, 1),
@@ -281,8 +285,11 @@ func flushPending(seq *Sequence) bool {
 	if len(seq.pendingResponses) == 0 {
 		return true
 	}
-	content := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
+	content := ""
+	for _, resp := range seq.pendingResponses {
+		content += resp.Content
+	}
+	seq.pendingResponses = []CompletionResponse{}
 
 	// Check if there are any partial UTF-8 characters remaining.
 	// We already check and queue as we are generating but some may
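The loop added above concatenates Content with +=, which is fine for small batches. A small standalone helper with a hypothetical name, joinPending, shows the same join written with strings.Builder; this is purely illustrative, not what the commit does, and it assumes the CompletionResponse sketch above plus a "strings" import:

// joinPending mirrors the loop added to flushPending: it concatenates the
// Content of each pending response into one string.
func joinPending(pending []CompletionResponse) string {
	var sb strings.Builder
	for _, resp := range pending {
		sb.WriteString(resp.Content)
	}
	return sb.String()
}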
@@ -362,27 +369,27 @@ func (s *Server) run(ctx context.Context) {
 	}
 }
 
-// TokenData represents probability information for a token
-type TokenData struct {
+// TokenProbs represents probability information for a token
+type TokenProbs struct {
 	TokenID int
 	Logit float32
 	Prob float32
 	LogProb float32
 }
 
-// getTokenProbabilities returns sorted token probabilities for a specific token index
-func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData {
+// probs returns sorted token probabilities for a specific token index
+func (s *Server) probs(seq *Sequence) []TokenProbs {
 	// Get logits for the specific token index
 	logits := s.lc.GetLogits()
 	seq.logits = make([]float32, len(logits))
 	copy(seq.logits, logits)
 
 	vocabSize := s.model.NumVocab()
-	probs := make([]TokenData, vocabSize)
+	probs := make([]TokenProbs, vocabSize)
 
 	// Initialize token data with logits
 	for i := 0; i < vocabSize; i++ {
-		probs[i] = TokenData{
+		probs[i] = TokenProbs{
 			TokenID: i,
 			Logit: logits[i],
 		}
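The hunk above ends while probs is still being filled in, so the softmax and sorting step is not visible here. A minimal sketch of how Prob and LogProb could be derived from the collected logits, assuming standard "math" and "sort" imports; this is an illustration, not necessarily what the commit's full function does:

// softmaxAndSort is a hypothetical helper (not from this commit): it converts
// raw logits into probabilities with a numerically stable softmax and sorts
// the tokens by descending probability.
func softmaxAndSort(probs []TokenProbs) []TokenProbs {
	maxLogit := float32(math.Inf(-1))
	for _, p := range probs {
		if p.Logit > maxLogit {
			maxLogit = p.Logit
		}
	}

	var sum float32
	for i := range probs {
		probs[i].Prob = float32(math.Exp(float64(probs[i].Logit - maxLogit)))
		sum += probs[i].Prob
	}
	for i := range probs {
		probs[i].Prob /= sum
		probs[i].LogProb = float32(math.Log(float64(probs[i].Prob)))
	}

	sort.Slice(probs, func(i, j int) bool { return probs[i].Prob > probs[j].Prob })
	return probs
}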
@@ -546,10 +553,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.numPredicted++
 
-		// TODO: only do this when flag specified
-		probs := s.getTokenProbabilities(seq)
-		for i := range 10 {
-			slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
+		if seq.logprobs > 0 {
+			// TODO: return selected token in logprobs always
+			// probs := s.probs(seq)
 		}
 
 		// if it's an end of sequence token, break
@@ -564,8 +570,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.inputs = []input{{token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		// TODO: add probs here
+		seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece})
+		var sequence string
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}
 
 		if ok, stop := findStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
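The two TODOs above leave the wiring of probabilities into responses for later. One hypothetical way the pending response for a piece could carry the top candidates, assuming CompletionResponse grows a LogProbs []TokenProbs field (not shown anywhere in this diff) and reusing the sorted output of probs:

// Hypothetical only: LogProbs is an assumed field and attachLogProbs is not
// part of this commit. It trims an already-sorted TokenProbs slice to the
// requested count and stores it on the response.
func attachLogProbs(resp *CompletionResponse, sorted []TokenProbs, n int) {
	if n <= 0 || len(sorted) == 0 {
		return
	}
	if n > len(sorted) {
		n = len(sorted)
	}
	resp.LogProbs = sorted[:n]
}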
@@ -29,40 +29,43 @@ func containsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
-func truncateStop(pieces []string, stop string) ([]string, bool) {
-	joined := strings.Join(pieces, "")
+func truncateStop(pieces []CompletionResponse, stop string) ([]CompletionResponse, bool) {
+	// Build complete string and find stop position
+	var completeStr string
+	for _, piece := range pieces {
+		completeStr += piece.Content
+	}
 
-	index := strings.Index(joined, stop)
-	if index == -1 {
+	stopStart := strings.Index(completeStr, stop)
+	if stopStart == -1 {
 		return pieces, false
 	}
 
-	joined = joined[:index]
+	// Build result up to stop position
+	result := make([]CompletionResponse, 0)
+	accumulated := 0
 
-	// Split truncated string back into pieces of original lengths
-	lengths := make([]int, len(pieces))
-	for i, piece := range pieces {
-		lengths[i] = len(piece)
-	}
-
-	var result []string
-	tokenTruncated := false
-	start := 0
-	for _, length := range lengths {
-		if start >= len(joined) {
-			break
+	truncated := false
+	for _, piece := range pieces {
+		if accumulated+len(piece.Content) <= stopStart {
+			result = append(result, piece)
+			accumulated += len(piece.Content)
+			continue
 		}
 
-		end := start + length
-		if end > len(joined) {
-			end = len(joined)
-			tokenTruncated = true
+		if accumulated < stopStart {
+			truncPiece := piece
+			truncPiece.Content = piece.Content[:stopStart-accumulated]
+			if len(truncPiece.Content) > 0 {
+				result = append(result, truncPiece)
+				truncated = true
+			}
 		}
-		result = append(result, joined[start:end])
-		start = end
+		break
 	}
 
-	return result, tokenTruncated
+	// Signal if we had to truncate the last piece
+	return result, truncated
 }
 
 func incompleteUnicode(token string) bool {
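A quick usage sketch of the new signature, matching the "Middle" case in the test hunk below; it assumes the CompletionResponse sketch above and the truncateStop shown in this hunk, and is a snippet rather than a complete program:

// "hello" + " wor" contains "llo w" starting at index 2, so only "he" survives
// and the second return value reports that the first piece was truncated.
pieces := []CompletionResponse{{Content: "hello"}, {Content: " wor"}}
out, truncated := truncateStop(pieces, "llo w")
fmt.Printf("%+v %v\n", out, truncated) // [{Content:he}] true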
@@ -8,44 +8,74 @@ import (
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name string
-		pieces []string
+		pieces []CompletionResponse
 		stop string
-		expected []string
+		expected []CompletionResponse
 		expectedTrunc bool
 	}{
 		{
-			name: "Single word",
-			pieces: []string{"hello", "world"},
-			stop: "world",
-			expected: []string{"hello"},
+			name: "Single word",
+			pieces: []CompletionResponse{
+				{Content: "hello"},
+				{Content: "world"},
+			},
+			stop: "world",
+			expected: []CompletionResponse{
+				{Content: "hello"},
+			},
 			expectedTrunc: false,
 		},
 		{
-			name: "Partial",
-			pieces: []string{"hello", "wor"},
-			stop: "or",
-			expected: []string{"hello", "w"},
+			name: "Partial",
+			pieces: []CompletionResponse{
+				{Content: "hello"},
+				{Content: "wor"},
+			},
+			stop: "or",
+			expected: []CompletionResponse{
+				{Content: "hello"},
+				{Content: "w"},
+			},
 			expectedTrunc: true,
 		},
 		{
-			name: "Suffix",
-			pieces: []string{"Hello", " there", "!"},
-			stop: "!",
-			expected: []string{"Hello", " there"},
+			name: "Suffix",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+				{Content: "!"},
+			},
+			stop: "!",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+			},
 			expectedTrunc: false,
 		},
 		{
-			name: "Suffix partial",
-			pieces: []string{"Hello", " the", "re!"},
-			stop: "there!",
-			expected: []string{"Hello", " "},
+			name: "Suffix partial",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " the"},
+				{Content: "re!"},
+			},
+			stop: "there!",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " "},
+			},
 			expectedTrunc: true,
 		},
 		{
-			name: "Middle",
-			pieces: []string{"hello", " wor"},
-			stop: "llo w",
-			expected: []string{"he"},
+			name: "Middle",
+			pieces: []CompletionResponse{
+				{Content: "hello"},
+				{Content: " wor"},
+			},
+			stop: "llo w",
+			expected: []CompletionResponse{
+				{Content: "he"},
+			},
 			expectedTrunc: true,
 		},
 	}