Compare commits
2 Commits
main
...
brucemacd/
Author | SHA1 | Date | |
---|---|---|---|
![]() |
946fdd5388 | ||
![]() |
905da35468 |
@ -2,6 +2,8 @@ package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func FindStop(sequence string, stops []string) (bool, string) {
|
||||
@ -29,40 +31,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
|
||||
// truncateStop removes the provided stop string from pieces,
|
||||
// returning the partial pieces with stop removed, including truncating
|
||||
// the last piece if required (and signalling if this was the case)
|
||||
func TruncateStop(pieces []string, stop string) ([]string, bool) {
|
||||
joined := strings.Join(pieces, "")
|
||||
|
||||
index := strings.Index(joined, stop)
|
||||
if index == -1 {
|
||||
return pieces, false
|
||||
func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
|
||||
var sequence string
|
||||
for _, resp := range resps {
|
||||
sequence += resp.Content
|
||||
}
|
||||
|
||||
joined = joined[:index]
|
||||
|
||||
// Split truncated string back into pieces of original lengths
|
||||
lengths := make([]int, len(pieces))
|
||||
for i, piece := range pieces {
|
||||
lengths[i] = len(piece)
|
||||
idx := strings.Index(sequence, stop)
|
||||
if idx < 0 {
|
||||
return resps, false
|
||||
}
|
||||
|
||||
var result []string
|
||||
tokenTruncated := false
|
||||
start := 0
|
||||
for _, length := range lengths {
|
||||
if start >= len(joined) {
|
||||
truncated := sequence[:idx]
|
||||
if len(truncated) == 0 {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
result := make([]llm.CompletionResponse, 0, len(resps))
|
||||
|
||||
// Track position in truncated sequence
|
||||
pos := 0
|
||||
truncationHappened := false
|
||||
for _, resp := range resps {
|
||||
if pos >= len(truncated) {
|
||||
break
|
||||
}
|
||||
|
||||
end := start + length
|
||||
if end > len(joined) {
|
||||
end = len(joined)
|
||||
tokenTruncated = true
|
||||
chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
|
||||
if len(chunk) < len(resp.Content) {
|
||||
truncationHappened = true
|
||||
}
|
||||
result = append(result, joined[start:end])
|
||||
start = end
|
||||
if len(chunk) > 0 {
|
||||
result = append(result, llm.CompletionResponse{Content: chunk})
|
||||
}
|
||||
pos += len(resp.Content)
|
||||
}
|
||||
|
||||
return result, tokenTruncated
|
||||
return result, truncationHappened
|
||||
}
|
||||
|
||||
func IncompleteUnicode(token string) bool {
|
||||
|
@ -1,51 +1,84 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestTruncateStop(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pieces []string
|
||||
pieces []llm.CompletionResponse
|
||||
stop string
|
||||
expected []string
|
||||
expected []llm.CompletionResponse
|
||||
expectedTrunc bool
|
||||
}{
|
||||
{
|
||||
name: "Single word",
|
||||
pieces: []string{"hello", "world"},
|
||||
stop: "world",
|
||||
expected: []string{"hello"},
|
||||
name: "Single word",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: "world"},
|
||||
},
|
||||
stop: "world",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
},
|
||||
expectedTrunc: false,
|
||||
},
|
||||
{
|
||||
name: "Partial",
|
||||
pieces: []string{"hello", "wor"},
|
||||
stop: "or",
|
||||
expected: []string{"hello", "w"},
|
||||
name: "Partial",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " wor"},
|
||||
},
|
||||
stop: "or",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " w"},
|
||||
},
|
||||
expectedTrunc: true,
|
||||
},
|
||||
{
|
||||
name: "Suffix",
|
||||
pieces: []string{"Hello", " there", "!"},
|
||||
stop: "!",
|
||||
expected: []string{"Hello", " there"},
|
||||
name: "Suffix",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " there"},
|
||||
{Content: "!"},
|
||||
},
|
||||
stop: "!",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " there"},
|
||||
},
|
||||
expectedTrunc: false,
|
||||
},
|
||||
{
|
||||
name: "Suffix partial",
|
||||
pieces: []string{"Hello", " the", "re!"},
|
||||
stop: "there!",
|
||||
expected: []string{"Hello", " "},
|
||||
name: "Suffix partial",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " the"},
|
||||
{Content: "re!"},
|
||||
},
|
||||
stop: "there!",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " "},
|
||||
},
|
||||
expectedTrunc: true,
|
||||
},
|
||||
{
|
||||
name: "Middle",
|
||||
pieces: []string{"hello", " wor"},
|
||||
stop: "llo w",
|
||||
expected: []string{"he"},
|
||||
name: "Middle",
|
||||
pieces: []llm.CompletionResponse{
|
||||
{Content: "Hello"},
|
||||
{Content: " wo"},
|
||||
},
|
||||
stop: "llo w",
|
||||
expected: []llm.CompletionResponse{
|
||||
{Content: "He"},
|
||||
},
|
||||
expectedTrunc: true,
|
||||
},
|
||||
}
|
||||
@ -54,12 +87,27 @@ func TestTruncateStop(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
|
||||
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
|
||||
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
|
||||
t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
|
||||
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func formatContentDiff(result, expected []llm.CompletionResponse) string {
|
||||
var s string
|
||||
for i := 0; i < len(result) || i < len(expected); i++ {
|
||||
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
|
||||
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
|
||||
} else if i < len(result) && i >= len(expected) {
|
||||
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
|
||||
} else if i >= len(result) && i < len(expected) {
|
||||
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestIncompleteUnicode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -51,7 +51,7 @@ type Sequence struct {
|
||||
pendingInputs []input
|
||||
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []string
|
||||
pendingResponses []llm.CompletionResponse
|
||||
|
||||
// input cache being used by this sequence
|
||||
cache *InputCacheSlot
|
||||
@ -61,7 +61,7 @@ type Sequence struct {
|
||||
crossAttention bool
|
||||
|
||||
// channel to send responses over
|
||||
responses chan string
|
||||
responses chan llm.CompletionResponse
|
||||
|
||||
// channel to stop decoding (such as if the remote connection is closed)
|
||||
quit chan bool
|
||||
@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
pendingResponses: make([]llm.CompletionResponse, 0),
|
||||
responses: make(chan llm.CompletionResponse, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplingCtx: sc,
|
||||
@ -276,29 +276,28 @@ func (s *Server) allNil() bool {
|
||||
}
|
||||
|
||||
func flushPending(seq *Sequence) bool {
|
||||
joined := strings.Join(seq.pendingResponses, "")
|
||||
seq.pendingResponses = []string{}
|
||||
pending := seq.pendingResponses
|
||||
seq.pendingResponses = []llm.CompletionResponse{}
|
||||
|
||||
// Check if there are any partial UTF-8 characters remaining.
|
||||
// We already check and queue as we are generating but some may
|
||||
// still make it here:
|
||||
// - Sequence is ending, e.g. generation limit has been hit
|
||||
// - Invalid characters in the middle of a string
|
||||
// This is a stricter check to ensure we never output invalid Unicode.
|
||||
for !utf8.ValidString(joined) {
|
||||
joined = joined[:len(joined)-1]
|
||||
}
|
||||
for i, r := range pending {
|
||||
if i == len(pending)-1 {
|
||||
// Check and trim any trailing partial UTF-8 characters
|
||||
content := r.Content
|
||||
for !utf8.ValidString(content) {
|
||||
content = content[:len(content)-1]
|
||||
}
|
||||
r.Content = content
|
||||
}
|
||||
|
||||
if len(joined) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case seq.responses <- joined:
|
||||
return true
|
||||
case <-seq.quit:
|
||||
return false
|
||||
select {
|
||||
case seq.responses <- r:
|
||||
return true
|
||||
case <-seq.quit:
|
||||
return false
|
||||
}
|
||||
}
|
||||
// no pending responses to send
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
|
||||
seq.inputs = []input{{token: token}}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
|
||||
sequence := ""
|
||||
for _, r := range seq.pendingResponses {
|
||||
sequence += r.Content
|
||||
}
|
||||
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||
@ -637,9 +639,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(&content); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
return
|
||||
|
@ -53,13 +53,13 @@ type Sequence struct {
|
||||
pendingInputs []input.Input
|
||||
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []string
|
||||
pendingResponses []llm.CompletionResponse
|
||||
|
||||
// input cache being used by this sequence
|
||||
cache *InputCacheSlot
|
||||
|
||||
// channel to send responses over
|
||||
responses chan string
|
||||
responses chan llm.CompletionResponse
|
||||
|
||||
// channel to stop decoding (such as if the remote connection is closed)
|
||||
quit chan bool
|
||||
@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
pendingResponses: make([]llm.CompletionResponse, 0),
|
||||
responses: make(chan llm.CompletionResponse, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
sampler: params.sampler,
|
||||
@ -288,29 +288,28 @@ func (s *Server) allNil() bool {
|
||||
}
|
||||
|
||||
func flushPending(seq *Sequence) bool {
|
||||
joined := strings.Join(seq.pendingResponses, "")
|
||||
seq.pendingResponses = []string{}
|
||||
pending := seq.pendingResponses
|
||||
seq.pendingResponses = []llm.CompletionResponse{}
|
||||
|
||||
// Check if there are any partial UTF-8 characters remaining.
|
||||
// We already check and queue as we are generating but some may
|
||||
// still make it here:
|
||||
// - Sequence is ending, e.g. generation limit has been hit
|
||||
// - Invalid characters in the middle of a string
|
||||
// This is a stricter check to ensure we never output invalid Unicode.
|
||||
for !utf8.ValidString(joined) {
|
||||
joined = joined[:len(joined)-1]
|
||||
}
|
||||
for i, r := range pending {
|
||||
if i == len(pending)-1 {
|
||||
// Check and trim any trailing partial UTF-8 characters
|
||||
content := r.Content
|
||||
for !utf8.ValidString(content) {
|
||||
content = content[:len(content)-1]
|
||||
}
|
||||
r.Content = content
|
||||
}
|
||||
|
||||
if len(joined) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case seq.responses <- joined:
|
||||
return true
|
||||
case <-seq.quit:
|
||||
return false
|
||||
select {
|
||||
case seq.responses <- r:
|
||||
return true
|
||||
case <-seq.quit:
|
||||
return false
|
||||
}
|
||||
}
|
||||
// no pending responses to send
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
@ -484,8 +483,11 @@ func (s *Server) processBatch() error {
|
||||
|
||||
seq.inputs = []input.Input{{Token: token}}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
|
||||
sequence := ""
|
||||
for _, r := range seq.pendingResponses {
|
||||
sequence += r.Content
|
||||
}
|
||||
|
||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||
@ -623,9 +625,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(&content); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user