runner: enable returning more info from runner processing

Currently we return only the text predicted from the LLM. This was nice in
that it was simple, but there may be other info we want to know from the
processing. This change adds the ability to return more information from the
runner than just the text predicted.

A follow up change will add logprobs to the response returned from the
runner using this structure.
This commit is contained in:
Bruce MacDonald 2025-02-21 16:31:31 -08:00
parent da0e345200
commit 905da35468
5 changed files with 175 additions and 99 deletions

View File

@ -29,40 +29,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces, // truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating // returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case) // the last piece if required (and signalling if this was the case)
func TruncateStop(pieces []string, stop string) ([]string, bool) { func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
joined := strings.Join(pieces, "") var sequence string
for _, resp := range resps {
index := strings.Index(joined, stop) sequence += resp.Content
if index == -1 {
return pieces, false
} }
joined = joined[:index] idx := strings.Index(sequence, stop)
if idx < 0 {
// Split truncated string back into pieces of original lengths return resps, false
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
} }
var result []string truncated := sequence[:idx]
tokenTruncated := false if len(truncated) == 0 {
start := 0 return nil, true
for _, length := range lengths { }
if start >= len(joined) {
result := make([]CompletionResponse, 0, len(resps))
// Track position in truncated sequence
pos := 0
truncationHappened := false
for _, resp := range resps {
if pos >= len(truncated) {
break break
} }
end := start + length chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
if end > len(joined) { if len(chunk) < len(resp.Content) {
end = len(joined) truncationHappened = true
tokenTruncated = true
} }
result = append(result, joined[start:end]) if len(chunk) > 0 {
start = end result = append(result, CompletionResponse{Content: chunk})
}
pos += len(resp.Content)
} }
return result, tokenTruncated return result, truncationHappened
} }
func IncompleteUnicode(token string) bool { func IncompleteUnicode(token string) bool {

View File

@ -1,6 +1,7 @@
package common package common
import ( import (
"fmt"
"reflect" "reflect"
"testing" "testing"
) )
@ -8,44 +9,74 @@ import (
func TestTruncateStop(t *testing.T) { func TestTruncateStop(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
pieces []string pieces []CompletionResponse
stop string stop string
expected []string expected []CompletionResponse
expectedTrunc bool expectedTrunc bool
}{ }{
{ {
name: "Single word", name: "Single word",
pieces: []string{"hello", "world"}, pieces: []CompletionResponse{
stop: "world", {Content: "Hello"},
expected: []string{"hello"}, {Content: "world"},
},
stop: "world",
expected: []CompletionResponse{
{Content: "Hello"},
},
expectedTrunc: false, expectedTrunc: false,
}, },
{ {
name: "Partial", name: "Partial",
pieces: []string{"hello", "wor"}, pieces: []CompletionResponse{
stop: "or", {Content: "Hello"},
expected: []string{"hello", "w"}, {Content: " wor"},
},
stop: "or",
expected: []CompletionResponse{
{Content: "Hello"},
{Content: " w"},
},
expectedTrunc: true, expectedTrunc: true,
}, },
{ {
name: "Suffix", name: "Suffix",
pieces: []string{"Hello", " there", "!"}, pieces: []CompletionResponse{
stop: "!", {Content: "Hello"},
expected: []string{"Hello", " there"}, {Content: " there"},
{Content: "!"},
},
stop: "!",
expected: []CompletionResponse{
{Content: "Hello"},
{Content: " there"},
},
expectedTrunc: false, expectedTrunc: false,
}, },
{ {
name: "Suffix partial", name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"}, pieces: []CompletionResponse{
stop: "there!", {Content: "Hello"},
expected: []string{"Hello", " "}, {Content: " the"},
{Content: "re!"},
},
stop: "there!",
expected: []CompletionResponse{
{Content: "Hello"},
{Content: " "},
},
expectedTrunc: true, expectedTrunc: true,
}, },
{ {
name: "Middle", name: "Middle",
pieces: []string{"hello", " wor"}, pieces: []CompletionResponse{
stop: "llo w", {Content: "Hello"},
expected: []string{"he"}, {Content: " wo"},
},
stop: "llo w",
expected: []CompletionResponse{
{Content: "He"},
},
expectedTrunc: true, expectedTrunc: true,
}, },
} }
@ -54,12 +85,27 @@ func TestTruncateStop(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := TruncateStop(tt.pieces, tt.stop) result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
} }
}) })
} }
} }
func formatContentDiff(result, expected []CompletionResponse) string {
var s string
for i := 0; i < len(result) || i < len(expected); i++ {
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
} else if i < len(result) && i >= len(expected) {
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
} else if i >= len(result) && i < len(expected) {
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
}
}
return s
}
func TestIncompleteUnicode(t *testing.T) { func TestIncompleteUnicode(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

23
runner/common/types.go Normal file
View File

@ -0,0 +1,23 @@
package common
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}

View File

@ -51,7 +51,7 @@ type Sequence struct {
pendingInputs []input pendingInputs []input
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []common.CompletionResponse
// input cache being used by this sequence // input cache being used by this sequence
cache *InputCacheSlot cache *InputCacheSlot
@ -61,7 +61,7 @@ type Sequence struct {
crossAttention bool crossAttention bool
// channel to send responses over // channel to send responses over
responses chan string responses chan common.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed) // channel to stop decoding (such as if the remote connection is closed)
quit chan bool quit chan bool
@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs), numPromptInputs: len(inputs),
startProcessingTime: startTime, startProcessingTime: startTime,
numPredict: params.numPredict, numPredict: params.numPredict,
pendingResponses: make([]string, 0), pendingResponses: make([]common.CompletionResponse, 0),
responses: make(chan string, 100), responses: make(chan common.CompletionResponse, 100),
quit: make(chan bool, 1), quit: make(chan bool, 1),
embedding: make(chan []float32, 1), embedding: make(chan []float32, 1),
samplingCtx: sc, samplingCtx: sc,
@ -276,29 +276,28 @@ func (s *Server) allNil() bool {
} }
func flushPending(seq *Sequence) bool { func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "") pending := seq.pendingResponses
seq.pendingResponses = []string{} seq.pendingResponses = []common.CompletionResponse{}
// Check if there are any partial UTF-8 characters remaining. for i, r := range pending {
// We already check and queue as we are generating but some may if i == len(pending)-1 {
// still make it here: // Check and trim any trailing partial UTF-8 characters
// - Sequence is ending, e.g. generation limit has been hit content := r.Content
// - Invalid characters in the middle of a string for !utf8.ValidString(content) {
// This is a stricter check to ensure we never output invalid Unicode. content = content[:len(content)-1]
for !utf8.ValidString(joined) { }
joined = joined[:len(joined)-1] r.Content = content
} }
if len(joined) == 0 { select {
return true case seq.responses <- r:
} return true
case <-seq.quit:
select { return false
case seq.responses <- joined: }
return true
case <-seq.quit:
return false
} }
// no pending responses to send
return true
} }
func (s *Server) removeSequence(seqIndex int, reason string) { func (s *Server) removeSequence(seqIndex int, reason string) {
@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.inputs = []input{{token: token}} seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece) seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
sequence := strings.Join(seq.pendingResponses, "") sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok { if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

View File

@ -53,13 +53,13 @@ type Sequence struct {
pendingInputs []input.Input pendingInputs []input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []common.CompletionResponse
// input cache being used by this sequence // input cache being used by this sequence
cache *InputCacheSlot cache *InputCacheSlot
// channel to send responses over // channel to send responses over
responses chan string responses chan common.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed) // channel to stop decoding (such as if the remote connection is closed)
quit chan bool quit chan bool
@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs), numPromptInputs: len(inputs),
startProcessingTime: startTime, startProcessingTime: startTime,
numPredict: params.numPredict, numPredict: params.numPredict,
pendingResponses: make([]string, 0), pendingResponses: make([]common.CompletionResponse, 0),
responses: make(chan string, 100), responses: make(chan common.CompletionResponse, 100),
quit: make(chan bool, 1), quit: make(chan bool, 1),
embedding: make(chan []float32, 1), embedding: make(chan []float32, 1),
sampler: params.sampler, sampler: params.sampler,
@ -288,29 +288,28 @@ func (s *Server) allNil() bool {
} }
func flushPending(seq *Sequence) bool { func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "") pending := seq.pendingResponses
seq.pendingResponses = []string{} seq.pendingResponses = []common.CompletionResponse{}
// Check if there are any partial UTF-8 characters remaining. for i, r := range pending {
// We already check and queue as we are generating but some may if i == len(pending)-1 {
// still make it here: // Check and trim any trailing partial UTF-8 characters
// - Sequence is ending, e.g. generation limit has been hit content := r.Content
// - Invalid characters in the middle of a string for !utf8.ValidString(content) {
// This is a stricter check to ensure we never output invalid Unicode. content = content[:len(content)-1]
for !utf8.ValidString(joined) { }
joined = joined[:len(joined)-1] r.Content = content
} }
if len(joined) == 0 { select {
return true case seq.responses <- r:
} return true
case <-seq.quit:
select { return false
case seq.responses <- joined: }
return true
case <-seq.quit:
return false
} }
// no pending responses to send
return true
} }
func (s *Server) removeSequence(seqIndex int, reason string) { func (s *Server) removeSequence(seqIndex int, reason string) {
@ -484,8 +483,11 @@ func (s *Server) processBatch() error {
seq.inputs = []input.Input{{Token: token}} seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece) seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
sequence := strings.Join(seq.pendingResponses, "") sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok { if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)