Compare commits

...

2 Commits

Author           SHA1        Message                                                     Date

Bruce MacDonald  946fdd5388  update completion responses                                 2025-03-19 10:00:00 -07:00

Bruce MacDonald  905da35468  runner: enable returning more info from runner processing   2025-03-19 09:46:07 -07:00

    Currently we return only the text predicted by the LLM. This was nice in
    that it was simple, but there may be other information we want to surface
    from processing. This change adds the ability to return more information
    from the runner than just the predicted text.

    A follow-up change will add logprobs to the response returned from the
    runner using this structure.
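For context, the diffs below only read and write the Content field of llm.CompletionResponse. A minimal sketch of the struct this change builds on might look like the following; the Logprobs field and the JSON tags are assumptions drawn from the commit message, not part of this diff:

package llm // hypothetical sketch, not the actual definition in the llm package

// CompletionResponse as assumed by the diffs below, which only use Content.
type CompletionResponse struct {
	Content string `json:"content"`

	// Logprobs is hypothetical here: per the commit message, a follow-up
	// change (not this one) would add logprobs on top of this structure.
	Logprobs []float64 `json:"logprobs,omitempty"`
}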
4 changed files with 158 additions and 105 deletions

File 1 of 4

@@ -2,6 +2,8 @@ package common
 import (
 	"strings"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func FindStop(sequence string, stops []string) (bool, string) {
@@ -29,40 +31,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
 
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
-func TruncateStop(pieces []string, stop string) ([]string, bool) {
-	joined := strings.Join(pieces, "")
-
-	index := strings.Index(joined, stop)
-	if index == -1 {
-		return pieces, false
-	}
-
-	joined = joined[:index]
-
-	// Split truncated string back into pieces of original lengths
-	lengths := make([]int, len(pieces))
-	for i, piece := range pieces {
-		lengths[i] = len(piece)
-	}
-
-	var result []string
-	tokenTruncated := false
-	start := 0
-	for _, length := range lengths {
-		if start >= len(joined) {
-			break
-		}
-
-		end := start + length
-		if end > len(joined) {
-			end = len(joined)
-			tokenTruncated = true
-		}
-		result = append(result, joined[start:end])
-		start = end
-	}
-
-	return result, tokenTruncated
+func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
+	var sequence string
+	for _, resp := range resps {
+		sequence += resp.Content
+	}
+
+	idx := strings.Index(sequence, stop)
+	if idx < 0 {
+		return resps, false
+	}
+
+	truncated := sequence[:idx]
+	if len(truncated) == 0 {
+		return nil, true
+	}
+
+	result := make([]llm.CompletionResponse, 0, len(resps))
+	// Track position in truncated sequence
+	pos := 0
+	truncationHappened := false
+	for _, resp := range resps {
+		if pos >= len(truncated) {
+			break
+		}
+
+		chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
+		if len(chunk) < len(resp.Content) {
+			truncationHappened = true
+		}
+		if len(chunk) > 0 {
+			result = append(result, llm.CompletionResponse{Content: chunk})
+		}
+		pos += len(resp.Content)
+	}
+
+	return result, truncationHappened
 }
 
 func IncompleteUnicode(token string) bool {
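A quick sketch of the new TruncateStop in use, mirroring the "Partial" test case in the next file. The runner/common import path is an assumption based on the "package common" shown above; only Content is set on each response:

package main

import (
	"fmt"

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/runner/common" // path assumed from "package common" above
)

func main() {
	// The stop string "or" ends inside the second chunk, so that chunk is
	// trimmed to " w" and truncation is reported.
	resps := []llm.CompletionResponse{{Content: "Hello"}, {Content: " wor"}}
	out, truncated := common.TruncateStop(resps, "or")
	for _, r := range out {
		fmt.Printf("%q ", r.Content) // "Hello" " w"
	}
	fmt.Println(truncated) // true
}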

File 2 of 4

@@ -1,51 +1,84 @@
 package common
 
 import (
+	"fmt"
 	"reflect"
 	"testing"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name          string
-		pieces        []string
+		pieces        []llm.CompletionResponse
 		stop          string
-		expected      []string
+		expected      []llm.CompletionResponse
 		expectedTrunc bool
 	}{
 		{
-			name:          "Single word",
-			pieces:        []string{"hello", "world"},
-			stop:          "world",
-			expected:      []string{"hello"},
+			name: "Single word",
+			pieces: []llm.CompletionResponse{
+				{Content: "Hello"},
+				{Content: "world"},
+			},
+			stop: "world",
+			expected: []llm.CompletionResponse{
+				{Content: "Hello"},
+			},
 			expectedTrunc: false,
 		},
 		{
-			name:          "Partial",
-			pieces:        []string{"hello", "wor"},
-			stop:          "or",
-			expected:      []string{"hello", "w"},
+			name: "Partial",
+			pieces: []llm.CompletionResponse{
+				{Content: "Hello"},
+				{Content: " wor"},
+			},
+			stop: "or",
+			expected: []llm.CompletionResponse{
+				{Content: "Hello"},
+				{Content: " w"},
+			},
 			expectedTrunc: true,
 		},
 		{
-			name:          "Suffix",
-			pieces:        []string{"Hello", " there", "!"},
-			stop:          "!",
-			expected:      []string{"Hello", " there"},
+			name: "Suffix",
+			pieces: []llm.CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+				{Content: "!"},
+			},
+			stop: "!",
+			expected: []llm.CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+			},
 			expectedTrunc: false,
 		},
 		{
-			name:          "Suffix partial",
-			pieces:        []string{"Hello", " the", "re!"},
-			stop:          "there!",
-			expected:      []string{"Hello", " "},
+			name: "Suffix partial",
+			pieces: []llm.CompletionResponse{
+				{Content: "Hello"},
+				{Content: " the"},
+				{Content: "re!"},
+			},
+			stop: "there!",
+			expected: []llm.CompletionResponse{
+				{Content: "Hello"},
+				{Content: " "},
+			},
 			expectedTrunc: true,
 		},
 		{
-			name:          "Middle",
-			pieces:        []string{"hello", " wor"},
-			stop:          "llo w",
-			expected:      []string{"he"},
+			name: "Middle",
+			pieces: []llm.CompletionResponse{
+				{Content: "Hello"},
+				{Content: " wo"},
+			},
+			stop: "llo w",
+			expected: []llm.CompletionResponse{
+				{Content: "He"},
+			},
 			expectedTrunc: true,
 		},
 	}
@@ -54,12 +87,27 @@ func TestTruncateStop(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
 			if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
-				t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
+				t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
+					tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
 			}
 		})
 	}
 }
 
+func formatContentDiff(result, expected []llm.CompletionResponse) string {
+	var s string
+	for i := 0; i < len(result) || i < len(expected); i++ {
+		if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
+			s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
+		} else if i < len(result) && i >= len(expected) {
+			s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
+		} else if i >= len(result) && i < len(expected) {
+			s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
+		}
+	}
+	return s
+}
+
 func TestIncompleteUnicode(t *testing.T) {
 	tests := []struct {
 		name    string
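For reference, formatContentDiff emits one line per mismatched index. A hypothetical failure where the second chunk differs and a third expected chunk is absent would print:

	[1] " w" vs " wo"
	[2] missing "rld"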

File 3 of 4

@@ -51,7 +51,7 @@ type Sequence struct {
 	pendingInputs []input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
@@ -61,7 +61,7 @@ type Sequence struct {
 	crossAttention bool
 
 	// channel to send responses over
-	responses chan string
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
@@ -276,29 +276,28 @@ func (s *Server) allNil() bool {
 }
 
 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
-
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
-	}
-
-	if len(joined) == 0 {
-		return true
-	}
-
-	select {
-	case seq.responses <- joined:
-		return true
-	case <-seq.quit:
-		return false
+	pending := seq.pendingResponses
+	seq.pendingResponses = []llm.CompletionResponse{}
+	for i, r := range pending {
+		if i == len(pending)-1 {
+			// Check and trim any trailing partial UTF-8 characters
+			content := r.Content
+			for !utf8.ValidString(content) {
+				content = content[:len(content)-1]
+			}
+			r.Content = content
+		}
+		select {
+		case seq.responses <- r:
+			return true
+		case <-seq.quit:
+			return false
+		}
 	}
+
+	// no pending responses to send
+	return true
 }
 
 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		seq.inputs = []input{{token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
+		sequence := ""
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}
 
 		if ok, stop := common.FindStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
@@ -637,9 +639,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
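The trailing-chunk trim in the new flushPending, isolated as a runnable sketch: a multi-byte rune cut at a chunk boundary is walked back byte by byte until what remains is valid UTF-8.

package main

import (
	"fmt"
	"unicode/utf8"
)

func main() {
	// "é" is two bytes in UTF-8; slicing after its first byte leaves an
	// invalid string, which the loop trims back to "caf".
	content := "café"[:4]
	for !utf8.ValidString(content) {
		content = content[:len(content)-1]
	}
	fmt.Printf("%q\n", content) // "caf"
}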

File 4 of 4

@@ -53,13 +53,13 @@ type Sequence struct {
 	pendingInputs []input.Input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
 	// channel to send responses over
-	responses chan string
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		sampler:             params.sampler,
@@ -288,29 +288,28 @@ func (s *Server) allNil() bool {
 }
 
 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
-
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
-	}
-
-	if len(joined) == 0 {
-		return true
-	}
-
-	select {
-	case seq.responses <- joined:
-		return true
-	case <-seq.quit:
-		return false
+	pending := seq.pendingResponses
+	seq.pendingResponses = []llm.CompletionResponse{}
+	for i, r := range pending {
+		if i == len(pending)-1 {
+			// Check and trim any trailing partial UTF-8 characters
+			content := r.Content
+			for !utf8.ValidString(content) {
+				content = content[:len(content)-1]
+			}
+			r.Content = content
+		}
+		select {
+		case seq.responses <- r:
+			return true
+		case <-seq.quit:
+			return false
+		}
 	}
+
+	// no pending responses to send
+	return true
 }
 
 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -484,8 +483,11 @@ func (s *Server) processBatch() error {
 		seq.inputs = []input.Input{{Token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
+		sequence := ""
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}
 
 		if ok, stop := common.FindStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
@@ -623,9 +625,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
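Since both handlers now encode the received struct directly, each streamed line is the JSON form of llm.CompletionResponse rather than a hand-built object. A sketch of what one chunk would look like on the wire, assuming a "content" JSON tag as in the hypothetical struct near the top of this page:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/llm"
)

func main() {
	// One streamed chunk, serialized the way the handler now does it.
	resp := llm.CompletionResponse{Content: "Hello"}
	b, _ := json.Marshal(&resp)
	fmt.Println(string(b)) // e.g. {"content":"Hello"}, per the struct's real tags
}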