diff --git a/llama/runner/runner.go b/llama/runner/runner.go index eda3d109b..2f15f4bec 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -50,8 +50,9 @@ type Sequence struct { // inputs that have been added to a batch but not yet submitted to Decode pendingInputs []input + // NOTE: entries are whole CompletionResponses so per-token metadata (e.g. logprobs) can accompany the text // tokens that have been generated but not returned yet (e.g. for stop sequences) - pendingResponses []string + pendingResponses []CompletionResponse // input cache being used by this sequence cache *InputCacheSlot @@ -87,6 +88,9 @@ type Sequence struct { logits []float32 + // number of logprobs to return with the completion response + logprobs int + // Metrics startProcessingTime time.Time startGenerationTime time.Time @@ -152,7 +156,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen numPromptInputs: len(inputs), startProcessingTime: startTime, numPredict: params.numPredict, - pendingResponses: make([]CompletionResponse, 0), + pendingResponses: make([]CompletionResponse, 0), responses: make(chan CompletionResponse, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), @@ -281,8 +285,11 @@ func flushPending(seq *Sequence) bool { if len(seq.pendingResponses) == 0 { return true } - content := strings.Join(seq.pendingResponses, "") - seq.pendingResponses = []string{} + content := "" + for _, resp := range seq.pendingResponses { + content += resp.Content + } + seq.pendingResponses = []CompletionResponse{} // Check if there are any partial UTF-8 characters remaining. 
// We already check and queue as we are generating but some may @@ -362,27 +369,27 @@ func (s *Server) run(ctx context.Context) { } } -// TokenData represents probability information for a token -type TokenData struct { +// TokenProbs represents probability information for a token +type TokenProbs struct { TokenID int Logit float32 Prob float32 LogProb float32 } -// getTokenProbabilities returns sorted token probabilities for a specific token index -func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData { +// probs returns sorted token probabilities for a specific token index +func (s *Server) probs(seq *Sequence) []TokenProbs { // Get logits for the specific token index logits := s.lc.GetLogits() seq.logits = make([]float32, len(logits)) copy(seq.logits, logits) vocabSize := s.model.NumVocab() - probs := make([]TokenData, vocabSize) + probs := make([]TokenProbs, vocabSize) // Initialize token data with logits for i := 0; i < vocabSize; i++ { - probs[i] = TokenData{ + probs[i] = TokenProbs{ TokenID: i, Logit: logits[i], } @@ -546,10 +553,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.numPredicted++ - // TODO: only do this when flag specified - probs := s.getTokenProbabilities(seq) - for i := range 10 { - slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID)) + if seq.logprobs > 0 { + // TODO: return selected token in logprobs always + // probs := s.probs(seq) } // if it's an end of sequence token, break @@ -564,8 +570,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.inputs = []input{{token: token}} - seq.pendingResponses = append(seq.pendingResponses, piece) - sequence := strings.Join(seq.pendingResponses, "") + // TODO: add probs here + seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece}) + var sequence string + for _, r := range 
seq.pendingResponses { + sequence += r.Content + } if ok, stop := findStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) diff --git a/llama/runner/stop.go b/llama/runner/stop.go index 8dcb08d33..76f5cb688 100644 --- a/llama/runner/stop.go +++ b/llama/runner/stop.go @@ -29,40 +29,43 @@ func containsStopSuffix(sequence string, stops []string) bool { // truncateStop removes the provided stop string from pieces, // returning the partial pieces with stop removed, including truncating // the last piece if required (and signalling if this was the case) -func truncateStop(pieces []string, stop string) ([]string, bool) { - joined := strings.Join(pieces, "") +func truncateStop(pieces []CompletionResponse, stop string) ([]CompletionResponse, bool) { + // Build complete string and find stop position + var completeStr string + for _, piece := range pieces { + completeStr += piece.Content + } - index := strings.Index(joined, stop) - if index == -1 { + stopStart := strings.Index(completeStr, stop) + if stopStart == -1 { return pieces, false } - joined = joined[:index] + // Build result up to stop position + result := make([]CompletionResponse, 0) + accumulated := 0 - // Split truncated string back into pieces of original lengths - lengths := make([]int, len(pieces)) - for i, piece := range pieces { - lengths[i] = len(piece) - } - - var result []string - tokenTruncated := false - start := 0 - for _, length := range lengths { - if start >= len(joined) { - break + truncated := false + for _, piece := range pieces { + if accumulated+len(piece.Content) <= stopStart { + result = append(result, piece) + accumulated += len(piece.Content) + continue } - end := start + length - if end > len(joined) { - end = len(joined) - tokenTruncated = true + if accumulated < stopStart { + truncPiece := piece + truncPiece.Content = piece.Content[:stopStart-accumulated] + if len(truncPiece.Content) > 0 { + result = append(result, truncPiece) + 
truncated = true + } } - result = append(result, joined[start:end]) - start = end + break } - return result, tokenTruncated + // Signal if we had to truncate the last piece + return result, truncated } func incompleteUnicode(token string) bool { diff --git a/llama/runner/stop_test.go b/llama/runner/stop_test.go index 31dc161f3..17174f400 100644 --- a/llama/runner/stop_test.go +++ b/llama/runner/stop_test.go @@ -8,44 +8,74 @@ import ( func TestTruncateStop(t *testing.T) { tests := []struct { name string - pieces []string + pieces []CompletionResponse stop string - expected []string + expected []CompletionResponse expectedTrunc bool }{ { - name: "Single word", - pieces: []string{"hello", "world"}, - stop: "world", - expected: []string{"hello"}, + name: "Single word", + pieces: []CompletionResponse{ + {Content: "hello"}, + {Content: "world"}, + }, + stop: "world", + expected: []CompletionResponse{ + {Content: "hello"}, + }, expectedTrunc: false, }, { - name: "Partial", - pieces: []string{"hello", "wor"}, - stop: "or", - expected: []string{"hello", "w"}, + name: "Partial", + pieces: []CompletionResponse{ + {Content: "hello"}, + {Content: "wor"}, + }, + stop: "or", + expected: []CompletionResponse{ + {Content: "hello"}, + {Content: "w"}, + }, expectedTrunc: true, }, { - name: "Suffix", - pieces: []string{"Hello", " there", "!"}, - stop: "!", - expected: []string{"Hello", " there"}, + name: "Suffix", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " there"}, + {Content: "!"}, + }, + stop: "!", + expected: []CompletionResponse{ + {Content: "Hello"}, + {Content: " there"}, + }, expectedTrunc: false, }, { - name: "Suffix partial", - pieces: []string{"Hello", " the", "re!"}, - stop: "there!", - expected: []string{"Hello", " "}, + name: "Suffix partial", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " the"}, + {Content: "re!"}, + }, + stop: "there!", + expected: []CompletionResponse{ + {Content: "Hello"}, + {Content: " "}, + }, 
expectedTrunc: true, }, { - name: "Middle", - pieces: []string{"hello", " wor"}, - stop: "llo w", - expected: []string{"he"}, + name: "Middle", + pieces: []CompletionResponse{ + {Content: "hello"}, + {Content: " wor"}, + }, + stop: "llo w", + expected: []CompletionResponse{ + {Content: "he"}, + }, expectedTrunc: true, }, }