runner: enable returning more info from runner processing
Currently the runner returns only the text predicted by the LLM. That was simple, but it prevents surfacing other information produced during processing. This change lets the runner return structured responses that carry more than just the predicted text. A follow-up change will use this structure to add logprobs to the response returned from the runner.
This commit is contained in:
parent
da0e345200
commit
905da35468
@ -29,40 +29,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
|
|||||||
// truncateStop removes the provided stop string from pieces,
|
// truncateStop removes the provided stop string from pieces,
|
||||||
// returning the partial pieces with stop removed, including truncating
|
// returning the partial pieces with stop removed, including truncating
|
||||||
// the last piece if required (and signalling if this was the case)
|
// the last piece if required (and signalling if this was the case)
|
||||||
func TruncateStop(pieces []string, stop string) ([]string, bool) {
|
func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
|
||||||
joined := strings.Join(pieces, "")
|
var sequence string
|
||||||
|
for _, resp := range resps {
|
||||||
index := strings.Index(joined, stop)
|
sequence += resp.Content
|
||||||
if index == -1 {
|
|
||||||
return pieces, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
joined = joined[:index]
|
idx := strings.Index(sequence, stop)
|
||||||
|
if idx < 0 {
|
||||||
// Split truncated string back into pieces of original lengths
|
return resps, false
|
||||||
lengths := make([]int, len(pieces))
|
|
||||||
for i, piece := range pieces {
|
|
||||||
lengths[i] = len(piece)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var result []string
|
truncated := sequence[:idx]
|
||||||
tokenTruncated := false
|
if len(truncated) == 0 {
|
||||||
start := 0
|
return nil, true
|
||||||
for _, length := range lengths {
|
}
|
||||||
if start >= len(joined) {
|
|
||||||
|
result := make([]CompletionResponse, 0, len(resps))
|
||||||
|
|
||||||
|
// Track position in truncated sequence
|
||||||
|
pos := 0
|
||||||
|
truncationHappened := false
|
||||||
|
for _, resp := range resps {
|
||||||
|
if pos >= len(truncated) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
end := start + length
|
chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
|
||||||
if end > len(joined) {
|
if len(chunk) < len(resp.Content) {
|
||||||
end = len(joined)
|
truncationHappened = true
|
||||||
tokenTruncated = true
|
|
||||||
}
|
}
|
||||||
result = append(result, joined[start:end])
|
if len(chunk) > 0 {
|
||||||
start = end
|
result = append(result, CompletionResponse{Content: chunk})
|
||||||
|
}
|
||||||
|
pos += len(resp.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, tokenTruncated
|
return result, truncationHappened
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncompleteUnicode(token string) bool {
|
func IncompleteUnicode(token string) bool {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -8,44 +9,74 @@ import (
|
|||||||
func TestTruncateStop(t *testing.T) {
|
func TestTruncateStop(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
pieces []string
|
pieces []CompletionResponse
|
||||||
stop string
|
stop string
|
||||||
expected []string
|
expected []CompletionResponse
|
||||||
expectedTrunc bool
|
expectedTrunc bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Single word",
|
name: "Single word",
|
||||||
pieces: []string{"hello", "world"},
|
pieces: []CompletionResponse{
|
||||||
stop: "world",
|
{Content: "Hello"},
|
||||||
expected: []string{"hello"},
|
{Content: "world"},
|
||||||
|
},
|
||||||
|
stop: "world",
|
||||||
|
expected: []CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
},
|
||||||
expectedTrunc: false,
|
expectedTrunc: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Partial",
|
name: "Partial",
|
||||||
pieces: []string{"hello", "wor"},
|
pieces: []CompletionResponse{
|
||||||
stop: "or",
|
{Content: "Hello"},
|
||||||
expected: []string{"hello", "w"},
|
{Content: " wor"},
|
||||||
|
},
|
||||||
|
stop: "or",
|
||||||
|
expected: []CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " w"},
|
||||||
|
},
|
||||||
expectedTrunc: true,
|
expectedTrunc: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Suffix",
|
name: "Suffix",
|
||||||
pieces: []string{"Hello", " there", "!"},
|
pieces: []CompletionResponse{
|
||||||
stop: "!",
|
{Content: "Hello"},
|
||||||
expected: []string{"Hello", " there"},
|
{Content: " there"},
|
||||||
|
{Content: "!"},
|
||||||
|
},
|
||||||
|
stop: "!",
|
||||||
|
expected: []CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " there"},
|
||||||
|
},
|
||||||
expectedTrunc: false,
|
expectedTrunc: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Suffix partial",
|
name: "Suffix partial",
|
||||||
pieces: []string{"Hello", " the", "re!"},
|
pieces: []CompletionResponse{
|
||||||
stop: "there!",
|
{Content: "Hello"},
|
||||||
expected: []string{"Hello", " "},
|
{Content: " the"},
|
||||||
|
{Content: "re!"},
|
||||||
|
},
|
||||||
|
stop: "there!",
|
||||||
|
expected: []CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " "},
|
||||||
|
},
|
||||||
expectedTrunc: true,
|
expectedTrunc: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Middle",
|
name: "Middle",
|
||||||
pieces: []string{"hello", " wor"},
|
pieces: []CompletionResponse{
|
||||||
stop: "llo w",
|
{Content: "Hello"},
|
||||||
expected: []string{"he"},
|
{Content: " wo"},
|
||||||
|
},
|
||||||
|
stop: "llo w",
|
||||||
|
expected: []CompletionResponse{
|
||||||
|
{Content: "He"},
|
||||||
|
},
|
||||||
expectedTrunc: true,
|
expectedTrunc: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -54,12 +85,27 @@ func TestTruncateStop(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
|
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
|
||||||
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
|
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
|
||||||
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
|
t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
|
||||||
|
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatContentDiff(result, expected []CompletionResponse) string {
|
||||||
|
var s string
|
||||||
|
for i := 0; i < len(result) || i < len(expected); i++ {
|
||||||
|
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
|
||||||
|
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
|
||||||
|
} else if i < len(result) && i >= len(expected) {
|
||||||
|
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
|
||||||
|
} else if i >= len(result) && i < len(expected) {
|
||||||
|
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func TestIncompleteUnicode(t *testing.T) {
|
func TestIncompleteUnicode(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
23
runner/common/types.go
Normal file
23
runner/common/types.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
// CompletionResponse is a single unit of output from the runner, carrying
// a piece of predicted text together with metadata about the completion.
// It replaces the bare string previously returned so that additional
// per-piece information (e.g. logprobs, in a follow-up change) can travel
// with the text.
type CompletionResponse struct {
	// Content is the text predicted by the model for this piece.
	Content string `json:"content"`
	// Stop reports whether generation has finished for the sequence.
	Stop bool `json:"stop"`

	// Model is the model name; omitted from JSON when empty.
	Model string `json:"model,omitempty"`
	// Prompt echoes the prompt; omitted from JSON when empty.
	Prompt string `json:"prompt,omitempty"`
	// StoppedLimit reports whether generation stopped because the
	// prediction limit was reached.
	StoppedLimit bool `json:"stopped_limit,omitempty"`
	// PredictedN / PredictedMS: predicted token count and wall-clock time.
	PredictedN  int     `json:"predicted_n,omitempty"`
	PredictedMS float64 `json:"predicted_ms,omitempty"`
	// PromptN / PromptMS: prompt token count and processing time.
	PromptN  int     `json:"prompt_n,omitempty"`
	PromptMS float64 `json:"prompt_ms,omitempty"`

	// Timings is always serialized (no omitempty), so consumers can rely
	// on its presence.
	Timings Timings `json:"timings"`
}

// Timings reports token counts and wall-clock durations (milliseconds)
// for the prompt-processing and prediction phases of a completion.
type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}
|
@ -51,7 +51,7 @@ type Sequence struct {
|
|||||||
pendingInputs []input
|
pendingInputs []input
|
||||||
|
|
||||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||||
pendingResponses []string
|
pendingResponses []common.CompletionResponse
|
||||||
|
|
||||||
// input cache being used by this sequence
|
// input cache being used by this sequence
|
||||||
cache *InputCacheSlot
|
cache *InputCacheSlot
|
||||||
@ -61,7 +61,7 @@ type Sequence struct {
|
|||||||
crossAttention bool
|
crossAttention bool
|
||||||
|
|
||||||
// channel to send responses over
|
// channel to send responses over
|
||||||
responses chan string
|
responses chan common.CompletionResponse
|
||||||
|
|
||||||
// channel to stop decoding (such as if the remote connection is closed)
|
// channel to stop decoding (such as if the remote connection is closed)
|
||||||
quit chan bool
|
quit chan bool
|
||||||
@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
numPromptInputs: len(inputs),
|
numPromptInputs: len(inputs),
|
||||||
startProcessingTime: startTime,
|
startProcessingTime: startTime,
|
||||||
numPredict: params.numPredict,
|
numPredict: params.numPredict,
|
||||||
pendingResponses: make([]string, 0),
|
pendingResponses: make([]common.CompletionResponse, 0),
|
||||||
responses: make(chan string, 100),
|
responses: make(chan common.CompletionResponse, 100),
|
||||||
quit: make(chan bool, 1),
|
quit: make(chan bool, 1),
|
||||||
embedding: make(chan []float32, 1),
|
embedding: make(chan []float32, 1),
|
||||||
samplingCtx: sc,
|
samplingCtx: sc,
|
||||||
@ -276,29 +276,28 @@ func (s *Server) allNil() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func flushPending(seq *Sequence) bool {
|
func flushPending(seq *Sequence) bool {
|
||||||
joined := strings.Join(seq.pendingResponses, "")
|
pending := seq.pendingResponses
|
||||||
seq.pendingResponses = []string{}
|
seq.pendingResponses = []common.CompletionResponse{}
|
||||||
|
|
||||||
// Check if there are any partial UTF-8 characters remaining.
|
for i, r := range pending {
|
||||||
// We already check and queue as we are generating but some may
|
if i == len(pending)-1 {
|
||||||
// still make it here:
|
// Check and trim any trailing partial UTF-8 characters
|
||||||
// - Sequence is ending, e.g. generation limit has been hit
|
content := r.Content
|
||||||
// - Invalid characters in the middle of a string
|
for !utf8.ValidString(content) {
|
||||||
// This is a stricter check to ensure we never output invalid Unicode.
|
content = content[:len(content)-1]
|
||||||
for !utf8.ValidString(joined) {
|
}
|
||||||
joined = joined[:len(joined)-1]
|
r.Content = content
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(joined) == 0 {
|
select {
|
||||||
return true
|
case seq.responses <- r:
|
||||||
}
|
return true
|
||||||
|
case <-seq.quit:
|
||||||
select {
|
return false
|
||||||
case seq.responses <- joined:
|
}
|
||||||
return true
|
|
||||||
case <-seq.quit:
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
// no pending responses to send
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||||
@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
|
|
||||||
seq.inputs = []input{{token: token}}
|
seq.inputs = []input{{token: token}}
|
||||||
|
|
||||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
|
||||||
sequence := strings.Join(seq.pendingResponses, "")
|
sequence := ""
|
||||||
|
for _, r := range seq.pendingResponses {
|
||||||
|
sequence += r.Content
|
||||||
|
}
|
||||||
|
|
||||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||||
|
@ -53,13 +53,13 @@ type Sequence struct {
|
|||||||
pendingInputs []input.Input
|
pendingInputs []input.Input
|
||||||
|
|
||||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||||
pendingResponses []string
|
pendingResponses []common.CompletionResponse
|
||||||
|
|
||||||
// input cache being used by this sequence
|
// input cache being used by this sequence
|
||||||
cache *InputCacheSlot
|
cache *InputCacheSlot
|
||||||
|
|
||||||
// channel to send responses over
|
// channel to send responses over
|
||||||
responses chan string
|
responses chan common.CompletionResponse
|
||||||
|
|
||||||
// channel to stop decoding (such as if the remote connection is closed)
|
// channel to stop decoding (such as if the remote connection is closed)
|
||||||
quit chan bool
|
quit chan bool
|
||||||
@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
numPromptInputs: len(inputs),
|
numPromptInputs: len(inputs),
|
||||||
startProcessingTime: startTime,
|
startProcessingTime: startTime,
|
||||||
numPredict: params.numPredict,
|
numPredict: params.numPredict,
|
||||||
pendingResponses: make([]string, 0),
|
pendingResponses: make([]common.CompletionResponse, 0),
|
||||||
responses: make(chan string, 100),
|
responses: make(chan common.CompletionResponse, 100),
|
||||||
quit: make(chan bool, 1),
|
quit: make(chan bool, 1),
|
||||||
embedding: make(chan []float32, 1),
|
embedding: make(chan []float32, 1),
|
||||||
sampler: params.sampler,
|
sampler: params.sampler,
|
||||||
@ -288,29 +288,28 @@ func (s *Server) allNil() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func flushPending(seq *Sequence) bool {
|
func flushPending(seq *Sequence) bool {
|
||||||
joined := strings.Join(seq.pendingResponses, "")
|
pending := seq.pendingResponses
|
||||||
seq.pendingResponses = []string{}
|
seq.pendingResponses = []common.CompletionResponse{}
|
||||||
|
|
||||||
// Check if there are any partial UTF-8 characters remaining.
|
for i, r := range pending {
|
||||||
// We already check and queue as we are generating but some may
|
if i == len(pending)-1 {
|
||||||
// still make it here:
|
// Check and trim any trailing partial UTF-8 characters
|
||||||
// - Sequence is ending, e.g. generation limit has been hit
|
content := r.Content
|
||||||
// - Invalid characters in the middle of a string
|
for !utf8.ValidString(content) {
|
||||||
// This is a stricter check to ensure we never output invalid Unicode.
|
content = content[:len(content)-1]
|
||||||
for !utf8.ValidString(joined) {
|
}
|
||||||
joined = joined[:len(joined)-1]
|
r.Content = content
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(joined) == 0 {
|
select {
|
||||||
return true
|
case seq.responses <- r:
|
||||||
}
|
return true
|
||||||
|
case <-seq.quit:
|
||||||
select {
|
return false
|
||||||
case seq.responses <- joined:
|
}
|
||||||
return true
|
|
||||||
case <-seq.quit:
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
// no pending responses to send
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||||
@ -484,8 +483,11 @@ func (s *Server) processBatch() error {
|
|||||||
|
|
||||||
seq.inputs = []input.Input{{Token: token}}
|
seq.inputs = []input.Input{{Token: token}}
|
||||||
|
|
||||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
|
||||||
sequence := strings.Join(seq.pendingResponses, "")
|
sequence := ""
|
||||||
|
for _, r := range seq.pendingResponses {
|
||||||
|
sequence += r.Content
|
||||||
|
}
|
||||||
|
|
||||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user