diff --git a/runner/common/stop.go b/runner/common/stop.go index 3f27a286e..748f3f0b7 100644 --- a/runner/common/stop.go +++ b/runner/common/stop.go @@ -29,40 +29,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool { // truncateStop removes the provided stop string from pieces, // returning the partial pieces with stop removed, including truncating // the last piece if required (and signalling if this was the case) -func TruncateStop(pieces []string, stop string) ([]string, bool) { - joined := strings.Join(pieces, "") - - index := strings.Index(joined, stop) - if index == -1 { - return pieces, false +func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) { + var sequence string + for _, resp := range resps { + sequence += resp.Content } - joined = joined[:index] - - // Split truncated string back into pieces of original lengths - lengths := make([]int, len(pieces)) - for i, piece := range pieces { - lengths[i] = len(piece) + idx := strings.Index(sequence, stop) + if idx < 0 { + return resps, false } - var result []string - tokenTruncated := false - start := 0 - for _, length := range lengths { - if start >= len(joined) { + truncated := sequence[:idx] + if len(truncated) == 0 { + return nil, true + } + + result := make([]CompletionResponse, 0, len(resps)) + + // Track position in truncated sequence + pos := 0 + truncationHappened := false + for _, resp := range resps { + if pos >= len(truncated) { break } - end := start + length - if end > len(joined) { - end = len(joined) - tokenTruncated = true + chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))] + if len(chunk) < len(resp.Content) { + truncationHappened = true } - result = append(result, joined[start:end]) - start = end + if len(chunk) > 0 { + result = append(result, CompletionResponse{Content: chunk}) + } + pos += len(resp.Content) } - return result, tokenTruncated + return result, truncationHappened } func IncompleteUnicode(token string) bool { diff --git a/runner/common/stop_test.go b/runner/common/stop_test.go index 8df267eb4..df68172fa 100644 --- a/runner/common/stop_test.go +++ b/runner/common/stop_test.go @@ -1,6 +1,7 @@ package common import ( + "fmt" "reflect" "testing" ) @@ -8,44 +9,74 @@ import ( func TestTruncateStop(t *testing.T) { tests := []struct { name string - pieces []string + pieces []CompletionResponse stop string - expected []string + expected []CompletionResponse expectedTrunc bool }{ { - name: "Single word", - pieces: []string{"hello", "world"}, - stop: "world", - expected: []string{"hello"}, + name: "Single word", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: "world"}, + }, + stop: "world", + expected: []CompletionResponse{ + {Content: "Hello"}, + }, expectedTrunc: false, }, { - name: "Partial", - pieces: []string{"hello", "wor"}, - stop: "or", - expected: []string{"hello", "w"}, + name: "Partial", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " wor"}, + }, + stop: "or", + expected: []CompletionResponse{ + {Content: "Hello"}, + {Content: " w"}, + }, expectedTrunc: true, }, { - name: "Suffix", - pieces: []string{"Hello", " there", "!"}, - stop: "!", - expected: []string{"Hello", " there"}, + name: "Suffix", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " there"}, + {Content: "!"}, + }, + stop: "!", + expected: []CompletionResponse{ + {Content: "Hello"}, + {Content: " there"}, + }, expectedTrunc: false, }, { - name: "Suffix partial", - pieces: []string{"Hello", " the", "re!"}, - stop: "there!", - expected: []string{"Hello", " "}, + name: "Suffix partial", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " the"}, + {Content: "re!"}, + }, + stop: "there!", + expected: []CompletionResponse{ + {Content: "Hello"}, + {Content: " "}, + }, expectedTrunc: true, }, { - name: "Middle", - pieces: []string{"hello", " wor"}, - stop: "llo w", - expected: []string{"he"}, + name: "Middle", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " wo"}, + }, + stop: "llo w", + expected: []CompletionResponse{ + {Content: "He"}, + }, expectedTrunc: true, }, } @@ -54,12 +85,27 @@ func TestTruncateStop(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result, resultTrunc := TruncateStop(tt.pieces, tt.stop) if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { - t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) + t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v", + tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc) } }) } } +func formatContentDiff(result, expected []CompletionResponse) string { + var s string + for i := 0; i < len(result) || i < len(expected); i++ { + if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content { + s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content) + } else if i < len(result) && i >= len(expected) { + s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content) + } else if i >= len(result) && i < len(expected) { + s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content) + } + } + return s +} + func TestIncompleteUnicode(t *testing.T) { tests := []struct { name string diff --git a/runner/common/types.go b/runner/common/types.go new file mode 100644 index 000000000..51487540a --- /dev/null +++ b/runner/common/types.go @@ -0,0 +1,23 @@ +package common + +type CompletionResponse struct { + Content string `json:"content"` + Stop bool `json:"stop"` + + Model string `json:"model,omitempty"` + Prompt string `json:"prompt,omitempty"` + StoppedLimit bool `json:"stopped_limit,omitempty"` + PredictedN int `json:"predicted_n,omitempty"` + PredictedMS float64 `json:"predicted_ms,omitempty"` + PromptN int `json:"prompt_n,omitempty"` + PromptMS float64 `json:"prompt_ms,omitempty"` + + Timings Timings `json:"timings"` +} + +type Timings struct { + PredictedN int `json:"predicted_n"` + PredictedMS float64 `json:"predicted_ms"` + PromptN int `json:"prompt_n"` + PromptMS float64 `json:"prompt_ms"` +} diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 83802d604..accbfddb3 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -51,7 +51,7 @@ type Sequence struct { pendingInputs []input // tokens that have been generated but not returned yet (e.g. for stop sequences) - pendingResponses []string + pendingResponses []common.CompletionResponse // input cache being used by this sequence cache *InputCacheSlot @@ -61,7 +61,7 @@ type Sequence struct { crossAttention bool // channel to send responses over - responses chan string + responses chan common.CompletionResponse // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe numPromptInputs: len(inputs), startProcessingTime: startTime, numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), + pendingResponses: make([]common.CompletionResponse, 0), + responses: make(chan common.CompletionResponse, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), samplingCtx: sc, @@ -276,29 +276,28 @@ func (s *Server) allNil() bool { } func flushPending(seq *Sequence) bool { - joined := strings.Join(seq.pendingResponses, "") - seq.pendingResponses = []string{} + pending := seq.pendingResponses + seq.pendingResponses = []common.CompletionResponse{} - // Check if there are any partial UTF-8 characters remaining. - // We already check and queue as we are generating but some may - // still make it here: - // - Sequence is ending, e.g. generation limit has been hit - // - Invalid characters in the middle of a string - // This is a stricter check to ensure we never output invalid Unicode. - for !utf8.ValidString(joined) { - joined = joined[:len(joined)-1] - } + for i, r := range pending { + if i == len(pending)-1 { + // Check and trim any trailing partial UTF-8 characters + content := r.Content + for !utf8.ValidString(content) { + content = content[:len(content)-1] + } + r.Content = content + } - if len(joined) == 0 { - return true - } - - select { - case seq.responses <- joined: - return true - case <-seq.quit: - return false + select { + case seq.responses <- r: + return true + case <-seq.quit: + return false + } } + // no pending responses to send + return true } func (s *Server) removeSequence(seqIndex int, reason string) { @@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.inputs = []input{{token: token}} - seq.pendingResponses = append(seq.pendingResponses, piece) - sequence := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece}) + sequence := "" + for _, r := range seq.pendingResponses { + sequence += r.Content + } if ok, stop := common.FindStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 9a1a549cd..11385a8a0 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -53,13 +53,13 @@ type Sequence struct { pendingInputs []input.Input // tokens that have been generated but not returned yet (e.g. for stop sequences) - pendingResponses []string + pendingResponses []common.CompletionResponse // input cache being used by this sequence cache *InputCacheSlot // channel to send responses over - responses chan string + responses chan common.CompletionResponse // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe numPromptInputs: len(inputs), startProcessingTime: startTime, numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), + pendingResponses: make([]common.CompletionResponse, 0), + responses: make(chan common.CompletionResponse, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), sampler: params.sampler, @@ -288,29 +288,28 @@ func (s *Server) allNil() bool { } func flushPending(seq *Sequence) bool { - joined := strings.Join(seq.pendingResponses, "") - seq.pendingResponses = []string{} + pending := seq.pendingResponses + seq.pendingResponses = []common.CompletionResponse{} - // Check if there are any partial UTF-8 characters remaining. - // We already check and queue as we are generating but some may - // still make it here: - // - Sequence is ending, e.g. generation limit has been hit - // - Invalid characters in the middle of a string - // This is a stricter check to ensure we never output invalid Unicode. - for !utf8.ValidString(joined) { - joined = joined[:len(joined)-1] - } + for i, r := range pending { + if i == len(pending)-1 { + // Check and trim any trailing partial UTF-8 characters + content := r.Content + for !utf8.ValidString(content) { + content = content[:len(content)-1] + } + r.Content = content + } - if len(joined) == 0 { - return true - } - - select { - case seq.responses <- joined: - return true - case <-seq.quit: - return false + select { + case seq.responses <- r: + return true + case <-seq.quit: + return false + } } + // no pending responses to send + return true } func (s *Server) removeSequence(seqIndex int, reason string) { @@ -484,8 +483,11 @@ func (s *Server) processBatch() error { seq.inputs = []input.Input{{Token: token}} - seq.pendingResponses = append(seq.pendingResponses, piece) - sequence := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece}) + sequence := "" + for _, r := range seq.pendingResponses { + sequence += r.Content + } if ok, stop := common.FindStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)