Compare commits

...

2 Commits

Author SHA1 Message Date
Bruce MacDonald
946fdd5388 update completion responses 2025-03-19 10:00:00 -07:00
Bruce MacDonald
905da35468 runner: enable returning more info from runner processing
Currently we return only the text predicted from the LLM. This was nice in
that it was simple, but there may be other info we want to know from the
processing. This change adds the ability to return more information from the
runner than just the text predicted.

A follow up change will add logprobs to the response returned from the
runner using this structure.
2025-03-19 09:46:07 -07:00
4 changed files with 158 additions and 105 deletions

View File

@ -2,6 +2,8 @@ package common
import (
"strings"
"github.com/ollama/ollama/llm"
)
func FindStop(sequence string, stops []string) (bool, string) {
@ -29,40 +31,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case)
func TruncateStop(pieces []string, stop string) ([]string, bool) {
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
if index == -1 {
return pieces, false
func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
var sequence string
for _, resp := range resps {
sequence += resp.Content
}
joined = joined[:index]
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
idx := strings.Index(sequence, stop)
if idx < 0 {
return resps, false
}
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
truncated := sequence[:idx]
if len(truncated) == 0 {
return nil, true
}
result := make([]llm.CompletionResponse, 0, len(resps))
// Track position in truncated sequence
pos := 0
truncationHappened := false
for _, resp := range resps {
if pos >= len(truncated) {
break
}
end := start + length
if end > len(joined) {
end = len(joined)
tokenTruncated = true
chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
if len(chunk) < len(resp.Content) {
truncationHappened = true
}
result = append(result, joined[start:end])
start = end
if len(chunk) > 0 {
result = append(result, llm.CompletionResponse{Content: chunk})
}
pos += len(resp.Content)
}
return result, tokenTruncated
return result, truncationHappened
}
func IncompleteUnicode(token string) bool {

View File

@ -1,51 +1,84 @@
package common
import (
"fmt"
"reflect"
"testing"
"github.com/ollama/ollama/llm"
)
func TestTruncateStop(t *testing.T) {
tests := []struct {
name string
pieces []string
pieces []llm.CompletionResponse
stop string
expected []string
expected []llm.CompletionResponse
expectedTrunc bool
}{
{
name: "Single word",
pieces: []string{"hello", "world"},
stop: "world",
expected: []string{"hello"},
name: "Single word",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: "world"},
},
stop: "world",
expected: []llm.CompletionResponse{
{Content: "Hello"},
},
expectedTrunc: false,
},
{
name: "Partial",
pieces: []string{"hello", "wor"},
stop: "or",
expected: []string{"hello", "w"},
name: "Partial",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " wor"},
},
stop: "or",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " w"},
},
expectedTrunc: true,
},
{
name: "Suffix",
pieces: []string{"Hello", " there", "!"},
stop: "!",
expected: []string{"Hello", " there"},
name: "Suffix",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " there"},
{Content: "!"},
},
stop: "!",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " there"},
},
expectedTrunc: false,
},
{
name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"},
stop: "there!",
expected: []string{"Hello", " "},
name: "Suffix partial",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " the"},
{Content: "re!"},
},
stop: "there!",
expected: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " "},
},
expectedTrunc: true,
},
{
name: "Middle",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: []string{"he"},
name: "Middle",
pieces: []llm.CompletionResponse{
{Content: "Hello"},
{Content: " wo"},
},
stop: "llo w",
expected: []llm.CompletionResponse{
{Content: "He"},
},
expectedTrunc: true,
},
}
@ -54,12 +87,27 @@ func TestTruncateStop(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
}
})
}
}
func formatContentDiff(result, expected []llm.CompletionResponse) string {
var s string
for i := 0; i < len(result) || i < len(expected); i++ {
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
} else if i < len(result) && i >= len(expected) {
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
} else if i >= len(result) && i < len(expected) {
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
}
}
return s
}
func TestIncompleteUnicode(t *testing.T) {
tests := []struct {
name string

View File

@ -51,7 +51,7 @@ type Sequence struct {
pendingInputs []input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
pendingResponses []llm.CompletionResponse
// input cache being used by this sequence
cache *InputCacheSlot
@ -61,7 +61,7 @@ type Sequence struct {
crossAttention bool
// channel to send responses over
responses chan string
responses chan llm.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
pendingResponses: make([]llm.CompletionResponse, 0),
responses: make(chan llm.CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
@ -276,29 +276,28 @@ func (s *Server) allNil() bool {
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
pending := seq.pendingResponses
seq.pendingResponses = []llm.CompletionResponse{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
for i, r := range pending {
if i == len(pending)-1 {
// Check and trim any trailing partial UTF-8 characters
content := r.Content
for !utf8.ValidString(content) {
content = content[:len(content)-1]
}
r.Content = content
}
if len(joined) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
select {
case seq.responses <- r:
return true
case <-seq.quit:
return false
}
}
// no pending responses to send
return true
}
func (s *Server) removeSequence(seqIndex int, reason string) {
@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
@ -637,9 +639,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
if err := json.NewEncoder(w).Encode(&content); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return

View File

@ -53,13 +53,13 @@ type Sequence struct {
pendingInputs []input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
pendingResponses []llm.CompletionResponse
// input cache being used by this sequence
cache *InputCacheSlot
// channel to send responses over
responses chan string
responses chan llm.CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
pendingResponses: make([]llm.CompletionResponse, 0),
responses: make(chan llm.CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
sampler: params.sampler,
@ -288,29 +288,28 @@ func (s *Server) allNil() bool {
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
pending := seq.pendingResponses
seq.pendingResponses = []llm.CompletionResponse{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
for i, r := range pending {
if i == len(pending)-1 {
// Check and trim any trailing partial UTF-8 characters
content := r.Content
for !utf8.ValidString(content) {
content = content[:len(content)-1]
}
r.Content = content
}
if len(joined) == 0 {
return true
}
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
select {
case seq.responses <- r:
return true
case <-seq.quit:
return false
}
}
// no pending responses to send
return true
}
func (s *Server) removeSequence(seqIndex int, reason string) {
@ -484,8 +483,11 @@ func (s *Server) processBatch() error {
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
sequence := ""
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
@ -623,9 +625,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
if err := json.NewEncoder(w).Encode(&content); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return