...
This commit is contained in:
parent
6dfcdec2da
commit
64f95067ba
@ -50,8 +50,9 @@ type Sequence struct {
|
|||||||
// inputs that have been added to a batch but not yet submitted to Decode
|
// inputs that have been added to a batch but not yet submitted to Decode
|
||||||
pendingInputs []input
|
pendingInputs []input
|
||||||
|
|
||||||
|
// TODO: update this comment
|
||||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||||
pendingResponses []string
|
pendingResponses []CompletionResponse
|
||||||
|
|
||||||
// input cache being used by this sequence
|
// input cache being used by this sequence
|
||||||
cache *InputCacheSlot
|
cache *InputCacheSlot
|
||||||
@ -87,6 +88,9 @@ type Sequence struct {
|
|||||||
|
|
||||||
logits []float32
|
logits []float32
|
||||||
|
|
||||||
|
// number of logprobs to return with the completion response
|
||||||
|
logprobs int
|
||||||
|
|
||||||
// Metrics
|
// Metrics
|
||||||
startProcessingTime time.Time
|
startProcessingTime time.Time
|
||||||
startGenerationTime time.Time
|
startGenerationTime time.Time
|
||||||
@ -152,7 +156,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|||||||
numPromptInputs: len(inputs),
|
numPromptInputs: len(inputs),
|
||||||
startProcessingTime: startTime,
|
startProcessingTime: startTime,
|
||||||
numPredict: params.numPredict,
|
numPredict: params.numPredict,
|
||||||
pendingResponses: make([]string, 0),
|
pendingResponses: make([]CompletionResponse, 0),
|
||||||
responses: make(chan CompletionResponse, 100),
|
responses: make(chan CompletionResponse, 100),
|
||||||
quit: make(chan bool, 1),
|
quit: make(chan bool, 1),
|
||||||
embedding: make(chan []float32, 1),
|
embedding: make(chan []float32, 1),
|
||||||
@ -281,8 +285,11 @@ func flushPending(seq *Sequence) bool {
|
|||||||
if len(seq.pendingResponses) == 0 {
|
if len(seq.pendingResponses) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
content := strings.Join(seq.pendingResponses, "")
|
content := ""
|
||||||
seq.pendingResponses = []string{}
|
for _, resp := range seq.pendingResponses {
|
||||||
|
content += resp.Content
|
||||||
|
}
|
||||||
|
seq.pendingResponses = []CompletionResponse{}
|
||||||
|
|
||||||
// Check if there are any partial UTF-8 characters remaining.
|
// Check if there are any partial UTF-8 characters remaining.
|
||||||
// We already check and queue as we are generating but some may
|
// We already check and queue as we are generating but some may
|
||||||
@ -362,27 +369,27 @@ func (s *Server) run(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenData represents probability information for a token
|
// TokenProbs represents probability information for a token
|
||||||
type TokenData struct {
|
type TokenProbs struct {
|
||||||
TokenID int
|
TokenID int
|
||||||
Logit float32
|
Logit float32
|
||||||
Prob float32
|
Prob float32
|
||||||
LogProb float32
|
LogProb float32
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTokenProbabilities returns sorted token probabilities for a specific token index
|
// probs returns sorted token probabilities for a specific token index
|
||||||
func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData {
|
func (s *Server) probs(seq *Sequence) []TokenProbs {
|
||||||
// Get logits for the specific token index
|
// Get logits for the specific token index
|
||||||
logits := s.lc.GetLogits()
|
logits := s.lc.GetLogits()
|
||||||
seq.logits = make([]float32, len(logits))
|
seq.logits = make([]float32, len(logits))
|
||||||
copy(seq.logits, logits)
|
copy(seq.logits, logits)
|
||||||
|
|
||||||
vocabSize := s.model.NumVocab()
|
vocabSize := s.model.NumVocab()
|
||||||
probs := make([]TokenData, vocabSize)
|
probs := make([]TokenProbs, vocabSize)
|
||||||
|
|
||||||
// Initialize token data with logits
|
// Initialize token data with logits
|
||||||
for i := 0; i < vocabSize; i++ {
|
for i := 0; i < vocabSize; i++ {
|
||||||
probs[i] = TokenData{
|
probs[i] = TokenProbs{
|
||||||
TokenID: i,
|
TokenID: i,
|
||||||
Logit: logits[i],
|
Logit: logits[i],
|
||||||
}
|
}
|
||||||
@ -546,10 +553,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
|
|
||||||
seq.numPredicted++
|
seq.numPredicted++
|
||||||
|
|
||||||
// TODO: only do this when flag specified
|
if seq.logprobs > 0 {
|
||||||
probs := s.getTokenProbabilities(seq)
|
// TODO: return selected token in logprobs always
|
||||||
for i := range 10 {
|
// probs := s.probs(seq)
|
||||||
slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// if it's an end of sequence token, break
|
// if it's an end of sequence token, break
|
||||||
@ -564,8 +570,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
|
|
||||||
seq.inputs = []input{{token: token}}
|
seq.inputs = []input{{token: token}}
|
||||||
|
|
||||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
// TODO: add probs here
|
||||||
sequence := strings.Join(seq.pendingResponses, "")
|
seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece})
|
||||||
|
var sequence string
|
||||||
|
for _, r := range seq.pendingResponses {
|
||||||
|
sequence += r.Content
|
||||||
|
}
|
||||||
|
|
||||||
if ok, stop := findStop(sequence, seq.stop); ok {
|
if ok, stop := findStop(sequence, seq.stop); ok {
|
||||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||||
|
@ -29,40 +29,43 @@ func containsStopSuffix(sequence string, stops []string) bool {
|
|||||||
// truncateStop removes the provided stop string from pieces,
|
// truncateStop removes the provided stop string from pieces,
|
||||||
// returning the partial pieces with stop removed, including truncating
|
// returning the partial pieces with stop removed, including truncating
|
||||||
// the last piece if required (and signalling if this was the case)
|
// the last piece if required (and signalling if this was the case)
|
||||||
func truncateStop(pieces []string, stop string) ([]string, bool) {
|
func truncateStop(pieces []CompletionResponse, stop string) ([]CompletionResponse, bool) {
|
||||||
joined := strings.Join(pieces, "")
|
// Build complete string and find stop position
|
||||||
|
var completeStr string
|
||||||
|
for _, piece := range pieces {
|
||||||
|
completeStr += piece.Content
|
||||||
|
}
|
||||||
|
|
||||||
index := strings.Index(joined, stop)
|
stopStart := strings.Index(completeStr, stop)
|
||||||
if index == -1 {
|
if stopStart == -1 {
|
||||||
return pieces, false
|
return pieces, false
|
||||||
}
|
}
|
||||||
|
|
||||||
joined = joined[:index]
|
// Build result up to stop position
|
||||||
|
result := make([]CompletionResponse, 0)
|
||||||
|
accumulated := 0
|
||||||
|
|
||||||
// Split truncated string back into pieces of original lengths
|
truncated := false
|
||||||
lengths := make([]int, len(pieces))
|
for _, piece := range pieces {
|
||||||
for i, piece := range pieces {
|
if accumulated+len(piece.Content) <= stopStart {
|
||||||
lengths[i] = len(piece)
|
result = append(result, piece)
|
||||||
|
accumulated += len(piece.Content)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var result []string
|
if accumulated < stopStart {
|
||||||
tokenTruncated := false
|
truncPiece := piece
|
||||||
start := 0
|
truncPiece.Content = piece.Content[:stopStart-accumulated]
|
||||||
for _, length := range lengths {
|
if len(truncPiece.Content) > 0 {
|
||||||
if start >= len(joined) {
|
result = append(result, truncPiece)
|
||||||
|
truncated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
end := start + length
|
// Signal if we had to truncate the last piece
|
||||||
if end > len(joined) {
|
return result, truncated
|
||||||
end = len(joined)
|
|
||||||
tokenTruncated = true
|
|
||||||
}
|
|
||||||
result = append(result, joined[start:end])
|
|
||||||
start = end
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, tokenTruncated
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func incompleteUnicode(token string) bool {
|
func incompleteUnicode(token string) bool {
|
||||||
|
@ -8,44 +8,74 @@ import (
|
|||||||
func TestTruncateStop(t *testing.T) {
|
func TestTruncateStop(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
pieces []string
|
pieces []CompletionResponse
|
||||||
stop string
|
stop string
|
||||||
expected []string
|
expected []CompletionResponse
|
||||||
expectedTrunc bool
|
expectedTrunc bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Single word",
|
name: "Single word",
|
||||||
pieces: []string{"hello", "world"},
|
pieces: []CompletionResponse{
|
||||||
|
{Content: "hello"},
|
||||||
|
{Content: "world"},
|
||||||
|
},
|
||||||
stop: "world",
|
stop: "world",
|
||||||
expected: []string{"hello"},
|
expected: []CompletionResponse{
|
||||||
|
{Content: "hello"},
|
||||||
|
},
|
||||||
expectedTrunc: false,
|
expectedTrunc: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Partial",
|
name: "Partial",
|
||||||
pieces: []string{"hello", "wor"},
|
pieces: []CompletionResponse{
|
||||||
|
{Content: "hello"},
|
||||||
|
{Content: "wor"},
|
||||||
|
},
|
||||||
stop: "or",
|
stop: "or",
|
||||||
expected: []string{"hello", "w"},
|
expected: []CompletionResponse{
|
||||||
|
{Content: "hello"},
|
||||||
|
{Content: "w"},
|
||||||
|
},
|
||||||
expectedTrunc: true,
|
expectedTrunc: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Suffix",
|
name: "Suffix",
|
||||||
pieces: []string{"Hello", " there", "!"},
|
pieces: []CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " there"},
|
||||||
|
{Content: "!"},
|
||||||
|
},
|
||||||
stop: "!",
|
stop: "!",
|
||||||
expected: []string{"Hello", " there"},
|
expected: []CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " there"},
|
||||||
|
},
|
||||||
expectedTrunc: false,
|
expectedTrunc: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Suffix partial",
|
name: "Suffix partial",
|
||||||
pieces: []string{"Hello", " the", "re!"},
|
pieces: []CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " the"},
|
||||||
|
{Content: "re!"},
|
||||||
|
},
|
||||||
stop: "there!",
|
stop: "there!",
|
||||||
expected: []string{"Hello", " "},
|
expected: []CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " "},
|
||||||
|
},
|
||||||
expectedTrunc: true,
|
expectedTrunc: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Middle",
|
name: "Middle",
|
||||||
pieces: []string{"hello", " wor"},
|
pieces: []CompletionResponse{
|
||||||
|
{Content: "hello"},
|
||||||
|
{Content: " wor"},
|
||||||
|
},
|
||||||
stop: "llo w",
|
stop: "llo w",
|
||||||
expected: []string{"he"},
|
expected: []CompletionResponse{
|
||||||
|
{Content: "he"},
|
||||||
|
},
|
||||||
expectedTrunc: true,
|
expectedTrunc: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user