runner: enable returning more info from runner processing

Currently we return only the text predicted from the LLM. This was nice in
that it was simple, but there may be other info we want to know from the
processing. This change adds the ability to return more information from the
runner than just the text predicted.

A follow up change will add logprobs to the response returned from the
runner using this structure.
This commit is contained in:
Bruce MacDonald 2025-02-21 16:31:31 -08:00
parent da0e345200
commit 905da35468
5 changed files with 175 additions and 99 deletions

View File

@ -29,40 +29,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case)
func TruncateStop(pieces []string, stop string) ([]string, bool) {
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
if index == -1 {
return pieces, false
func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
var sequence string
for _, resp := range resps {
sequence += resp.Content
}
joined = joined[:index]
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
idx := strings.Index(sequence, stop)
if idx < 0 {
return resps, false
}
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
truncated := sequence[:idx]
if len(truncated) == 0 {
return nil, true
}
result := make([]CompletionResponse, 0, len(resps))
// Track position in truncated sequence
pos := 0
truncationHappened := false
for _, resp := range resps {
if pos >= len(truncated) {
break
}
end := start + length
if end > len(joined) {
end = len(joined)
tokenTruncated = true
chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
if len(chunk) < len(resp.Content) {
truncationHappened = true
}
result = append(result, joined[start:end])
start = end
if len(chunk) > 0 {
result = append(result, CompletionResponse{Content: chunk})
}
pos += len(resp.Content)
}
return result, tokenTruncated
return result, truncationHappened
}
func IncompleteUnicode(token string) bool {

View File

@ -1,6 +1,7 @@
package common
import (
"fmt"
"reflect"
"testing"
)
@ -8,44 +9,74 @@ import (
func TestTruncateStop(t *testing.T) {
tests := []struct {
name string
pieces []string
pieces []CompletionResponse
stop string
expected []string
expected []CompletionResponse
expectedTrunc bool
}{
{
name: "Single word",
pieces: []string{"hello", "world"},
stop: "world",
expected: []string{"hello"},
name: "Single word",
pieces: []CompletionResponse{
{Content: "Hello"},
{Content: "world"},
},
stop: "world",
expected: []CompletionResponse{
{Content: "Hello"},
},
expectedTrunc: false,
},
{
name: "Partial",
pieces: []string{"hello", "wor"},
stop: "or",
expected: []string{"hello", "w"},
name: "Partial",
pieces: []CompletionResponse{
{Content: "Hello"},
{Content: " wor"},
},
stop: "or",
expected: []CompletionResponse{
{Content: "Hello"},
{Content: " w"},
},
expectedTrunc: true,
},
{
name: "Suffix",
pieces: []string{"Hello", " there", "!"},
stop: "!",
expected: []string{"Hello", " there"},
name: "Suffix",
pieces: []CompletionResponse{
{Content: "Hello"},
{Content: " there"},
{Content: "!"},
},
stop: "!",
expected: []CompletionResponse{
{Content: "Hello"},
{Content: " there"},
},
expectedTrunc: false,
},
{
name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"},
stop: "there!",
expected: []string{"Hello", " "},
name: "Suffix partial",
pieces: []CompletionResponse{
{Content: "Hello"},
{Content: " the"},
{Content: "re!"},
},
stop: "there!",
expected: []CompletionResponse{
{Content: "Hello"},
{Content: " "},
},
expectedTrunc: true,
},
{
name: "Middle",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: []string{"he"},
name: "Middle",
pieces: []CompletionResponse{
{Content: "Hello"},
{Content: " wo"},
},
stop: "llo w",
expected: []CompletionResponse{
{Content: "He"},
},
expectedTrunc: true,
},
}
@ -54,12 +85,27 @@ func TestTruncateStop(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
}
})
}
}
func formatContentDiff(result, expected []CompletionResponse) string {
var s string
for i := 0; i < len(result) || i < len(expected); i++ {
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
} else if i < len(result) && i >= len(expected) {
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
} else if i >= len(result) && i < len(expected) {
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
}
}
return s
}
func TestIncompleteUnicode(t *testing.T) {
tests := []struct {
name string

23
runner/common/types.go Normal file
View File

@ -0,0 +1,23 @@
package common
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}

View File

@ -51,7 +51,7 @@ type Sequence struct {
pendingInputs []input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
pendingResponses []common.CompletionResponse
// input cache being used by this sequence
cache *InputCacheSlot
@ -61,7 +61,7 @@ type Sequence struct {
crossAttention bool
// channel to send responses over
responses chan string
responses chan common.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
pendingResponses: make([]common.CompletionResponse, 0),
responses: make(chan common.CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
@ -276,29 +276,28 @@ func (s *Server) allNil() bool {
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
pending := seq.pendingResponses
seq.pendingResponses = []common.CompletionResponse{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
for i, r := range pending {
if i == len(pending)-1 {
// Check and trim any trailing partial UTF-8 characters
content := r.Content
for !utf8.ValidString(content) {
content = content[:len(content)-1]
}
r.Content = content
}
if len(joined) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
select {
case seq.responses <- r:
return true
case <-seq.quit:
return false
}
}
// no pending responses to send
return true
}
func (s *Server) removeSequence(seqIndex int, reason string) {
@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

View File

@ -53,13 +53,13 @@ type Sequence struct {
pendingInputs []input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
pendingResponses []common.CompletionResponse
// input cache being used by this sequence
cache *InputCacheSlot
// channel to send responses over
responses chan string
responses chan common.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
pendingResponses: make([]common.CompletionResponse, 0),
responses: make(chan common.CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
sampler: params.sampler,
@ -288,29 +288,28 @@ func (s *Server) allNil() bool {
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
pending := seq.pendingResponses
seq.pendingResponses = []common.CompletionResponse{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
for i, r := range pending {
if i == len(pending)-1 {
// Check and trim any trailing partial UTF-8 characters
content := r.Content
for !utf8.ValidString(content) {
content = content[:len(content)-1]
}
r.Content = content
}
if len(joined) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
select {
case seq.responses <- r:
return true
case <-seq.quit:
return false
}
}
// no pending responses to send
return true
}
func (s *Server) removeSequence(seqIndex int, reason string) {
@ -484,8 +483,11 @@ func (s *Server) processBatch() error {
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)