From 905da35468ccca697425d65509bafb990f52af80 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 21 Feb 2025 16:31:31 -0800 Subject: [PATCH] runner: enable returning more info from runner processing Currently we return only the text predicted from the LLM. This was nice in that it was simple, but there may be other info we want to know from the processing. This change adds the ability to return more information from the runner than just the text predicted. A follow up change will add logprobs to the response returned from the runner using this structure. --- runner/common/stop.go | 51 ++++++++++--------- runner/common/stop_test.go | 92 ++++++++++++++++++++++++++--------- runner/common/types.go | 23 +++++++++ runner/llamarunner/runner.go | 54 ++++++++++---------- runner/ollamarunner/runner.go | 54 ++++++++++---------- 5 files changed, 175 insertions(+), 99 deletions(-) create mode 100644 runner/common/types.go diff --git a/runner/common/stop.go b/runner/common/stop.go index 3f27a286e..748f3f0b7 100644 --- a/runner/common/stop.go +++ b/runner/common/stop.go @@ -29,40 +29,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool { // truncateStop removes the provided stop string from pieces, // returning the partial pieces with stop removed, including truncating // the last piece if required (and signalling if this was the case) -func TruncateStop(pieces []string, stop string) ([]string, bool) { - joined := strings.Join(pieces, "") - - index := strings.Index(joined, stop) - if index == -1 { - return pieces, false +func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) { + var sequence string + for _, resp := range resps { + sequence += resp.Content } - joined = joined[:index] - - // Split truncated string back into pieces of original lengths - lengths := make([]int, len(pieces)) - for i, piece := range pieces { - lengths[i] = len(piece) + idx := strings.Index(sequence, stop) + if idx < 0 { + return resps, false } - var result []string - tokenTruncated := false - start := 0 - for _, length := range lengths { - if start >= len(joined) { + truncated := sequence[:idx] + if len(truncated) == 0 { + return nil, true + } + + result := make([]CompletionResponse, 0, len(resps)) + + // Track position in truncated sequence + pos := 0 + truncationHappened := false + for _, resp := range resps { + if pos >= len(truncated) { break } - end := start + length - if end > len(joined) { - end = len(joined) - tokenTruncated = true + chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))] + if len(chunk) < len(resp.Content) { + truncationHappened = true } - result = append(result, joined[start:end]) - start = end + if len(chunk) > 0 { + result = append(result, CompletionResponse{Content: chunk}) + } + pos += len(resp.Content) } - return result, tokenTruncated + return result, truncationHappened } func IncompleteUnicode(token string) bool { diff --git a/runner/common/stop_test.go b/runner/common/stop_test.go index 8df267eb4..df68172fa 100644 --- a/runner/common/stop_test.go +++ b/runner/common/stop_test.go @@ -1,6 +1,7 @@ package common import ( + "fmt" "reflect" "testing" ) @@ -8,44 +9,74 @@ import ( func TestTruncateStop(t *testing.T) { tests := []struct { name string - pieces []string + pieces []CompletionResponse stop string - expected []string + expected []CompletionResponse expectedTrunc bool }{ { - name: "Single word", - pieces: []string{"hello", "world"}, - stop: "world", - expected: []string{"hello"}, + name: "Single word", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: "world"}, + }, + stop: "world", + expected: []CompletionResponse{ + {Content: "Hello"}, + }, expectedTrunc: false, }, { - name: "Partial", - pieces: []string{"hello", "wor"}, - stop: "or", - expected: []string{"hello", "w"}, + name: "Partial", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " wor"}, + }, + stop: "or", + expected: []CompletionResponse{ + {Content: "Hello"}, + {Content: " w"}, + }, expectedTrunc: true, }, { - name: "Suffix", - pieces: []string{"Hello", " there", "!"}, - stop: "!", - expected: []string{"Hello", " there"}, + name: "Suffix", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " there"}, + {Content: "!"}, + }, + stop: "!", + expected: []CompletionResponse{ + {Content: "Hello"}, + {Content: " there"}, + }, expectedTrunc: false, }, { - name: "Suffix partial", - pieces: []string{"Hello", " the", "re!"}, - stop: "there!", - expected: []string{"Hello", " "}, + name: "Suffix partial", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " the"}, + {Content: "re!"}, + }, + stop: "there!", + expected: []CompletionResponse{ + {Content: "Hello"}, + {Content: " "}, + }, expectedTrunc: true, }, { - name: "Middle", - pieces: []string{"hello", " wor"}, - stop: "llo w", - expected: []string{"he"}, + name: "Middle", + pieces: []CompletionResponse{ + {Content: "Hello"}, + {Content: " wo"}, + }, + stop: "llo w", + expected: []CompletionResponse{ + {Content: "He"}, + }, expectedTrunc: true, }, } @@ -54,12 +85,27 @@ func TestTruncateStop(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result, resultTrunc := TruncateStop(tt.pieces, tt.stop) if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { - t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) + t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v", + tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc) } }) } } +func formatContentDiff(result, expected []CompletionResponse) string { + var s string + for i := 0; i < len(result) || i < len(expected); i++ { + if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content { + s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content) + } else if i < len(result) && i >= len(expected) { + s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content) + } else if i >= len(result) && i < len(expected) { + s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content) + } + } + return s +} + func TestIncompleteUnicode(t *testing.T) { tests := []struct { name string diff --git a/runner/common/types.go b/runner/common/types.go new file mode 100644 index 000000000..51487540a --- /dev/null +++ b/runner/common/types.go @@ -0,0 +1,23 @@ +package common + +type CompletionResponse struct { + Content string `json:"content"` + Stop bool `json:"stop"` + + Model string `json:"model,omitempty"` + Prompt string `json:"prompt,omitempty"` + StoppedLimit bool `json:"stopped_limit,omitempty"` + PredictedN int `json:"predicted_n,omitempty"` + PredictedMS float64 `json:"predicted_ms,omitempty"` + PromptN int `json:"prompt_n,omitempty"` + PromptMS float64 `json:"prompt_ms,omitempty"` + + Timings Timings `json:"timings"` +} + +type Timings struct { + PredictedN int `json:"predicted_n"` + PredictedMS float64 `json:"predicted_ms"` + PromptN int `json:"prompt_n"` + PromptMS float64 `json:"prompt_ms"` +} diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 83802d604..accbfddb3 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -51,7 +51,7 @@ type Sequence struct { pendingInputs []input // tokens that have been generated but not returned yet (e.g. for stop sequences) - pendingResponses []string + pendingResponses []common.CompletionResponse // input cache being used by this sequence cache *InputCacheSlot @@ -61,7 +61,7 @@ type Sequence struct { crossAttention bool // channel to send responses over - responses chan string + responses chan common.CompletionResponse // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe numPromptInputs: len(inputs), startProcessingTime: startTime, numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), + pendingResponses: make([]common.CompletionResponse, 0), + responses: make(chan common.CompletionResponse, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), samplingCtx: sc, @@ -276,29 +276,28 @@ func (s *Server) allNil() bool { } func flushPending(seq *Sequence) bool { - joined := strings.Join(seq.pendingResponses, "") - seq.pendingResponses = []string{} + pending := seq.pendingResponses + seq.pendingResponses = []common.CompletionResponse{} - // Check if there are any partial UTF-8 characters remaining. - // We already check and queue as we are generating but some may - // still make it here: - // - Sequence is ending, e.g. generation limit has been hit - // - Invalid characters in the middle of a string - // This is a stricter check to ensure we never output invalid Unicode. - for !utf8.ValidString(joined) { - joined = joined[:len(joined)-1] - } + for i, r := range pending { + if i == len(pending)-1 { + // Check and trim any trailing partial UTF-8 characters + content := r.Content + for !utf8.ValidString(content) { + content = content[:len(content)-1] + } + r.Content = content + } - if len(joined) == 0 { - return true - } - - select { - case seq.responses <- joined: - return true - case <-seq.quit: - return false + select { + case seq.responses <- r: + return true + case <-seq.quit: + return false + } } + // no pending responses to send + return true } func (s *Server) removeSequence(seqIndex int, reason string) { @@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.inputs = []input{{token: token}} - seq.pendingResponses = append(seq.pendingResponses, piece) - sequence := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece}) + sequence := "" + for _, r := range seq.pendingResponses { + sequence += r.Content + } if ok, stop := common.FindStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 9a1a549cd..11385a8a0 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -53,13 +53,13 @@ type Sequence struct { pendingInputs []input.Input // tokens that have been generated but not returned yet (e.g. for stop sequences) - pendingResponses []string + pendingResponses []common.CompletionResponse // input cache being used by this sequence cache *InputCacheSlot // channel to send responses over - responses chan string + responses chan common.CompletionResponse // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe numPromptInputs: len(inputs), startProcessingTime: startTime, numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), + pendingResponses: make([]common.CompletionResponse, 0), + responses: make(chan common.CompletionResponse, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), sampler: params.sampler, @@ -288,29 +288,28 @@ func (s *Server) allNil() bool { } func flushPending(seq *Sequence) bool { - joined := strings.Join(seq.pendingResponses, "") - seq.pendingResponses = []string{} + pending := seq.pendingResponses + seq.pendingResponses = []common.CompletionResponse{} - // Check if there are any partial UTF-8 characters remaining. - // We already check and queue as we are generating but some may - // still make it here: - // - Sequence is ending, e.g. generation limit has been hit - // - Invalid characters in the middle of a string - // This is a stricter check to ensure we never output invalid Unicode. - for !utf8.ValidString(joined) { - joined = joined[:len(joined)-1] - } + for i, r := range pending { + if i == len(pending)-1 { + // Check and trim any trailing partial UTF-8 characters + content := r.Content + for !utf8.ValidString(content) { + content = content[:len(content)-1] + } + r.Content = content + } - if len(joined) == 0 { - return true - } - - select { - case seq.responses <- joined: - return true - case <-seq.quit: - return false + select { + case seq.responses <- r: + return true + case <-seq.quit: + return false + } } + // no pending responses to send + return true } func (s *Server) removeSequence(seqIndex int, reason string) { @@ -484,8 +483,11 @@ func (s *Server) processBatch() error { seq.inputs = []input.Input{{Token: token}} - seq.pendingResponses = append(seq.pendingResponses, piece) - sequence := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece}) + sequence := "" + for _, r := range seq.pendingResponses { + sequence += r.Content + } if ok, stop := common.FindStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)