update completion responses

Bruce MacDonald 2025-03-19 10:00:00 -07:00
parent 905da35468
commit 946fdd5388
5 changed files with 34 additions and 57 deletions

View File

@@ -2,6 +2,8 @@ package common
 
 import (
 	"strings"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func FindStop(sequence string, stops []string) (bool, string) {
@@ -29,7 +31,7 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
-func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
+func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
 	var sequence string
 	for _, resp := range resps {
 		sequence += resp.Content
@@ -45,7 +47,7 @@ func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
 		return nil, true
 	}
 
-	result := make([]CompletionResponse, 0, len(resps))
+	result := make([]llm.CompletionResponse, 0, len(resps))
 
 	// Track position in truncated sequence
 	pos := 0
@@ -60,7 +62,7 @@ func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
 			truncationHappened = true
 		}
 		if len(chunk) > 0 {
-			result = append(result, CompletionResponse{Content: chunk})
+			result = append(result, llm.CompletionResponse{Content: chunk})
 		}
 		pos += len(resp.Content)
 	}
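
For illustration, here is a minimal sketch of calling the updated TruncateStop, mirroring the "Partial" test case in the next file. The runner/common import path is an assumption, since this page does not show file names:

package main

import (
	"fmt"

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/runner/common" // assumed path; file names are not shown on this page
)

func main() {
	// The stop "or" ends inside the second piece, so TruncateStop trims
	// " wor" down to " w" and reports that a piece was truncated.
	pieces := []llm.CompletionResponse{
		{Content: "Hello"},
		{Content: " wor"},
	}
	trimmed, truncated := common.TruncateStop(pieces, "or")
	for _, r := range trimmed {
		fmt.Printf("%q ", r.Content) // "Hello" " w"
	}
	fmt.Println(truncated) // true
}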

View File

@@ -4,36 +4,38 @@ import (
 	"fmt"
 	"reflect"
 	"testing"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name          string
-		pieces        []CompletionResponse
+		pieces        []llm.CompletionResponse
 		stop          string
-		expected      []CompletionResponse
+		expected      []llm.CompletionResponse
 		expectedTrunc bool
 	}{
 		{
 			name: "Single word",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: "world"},
 			},
 			stop: "world",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 			},
 			expectedTrunc: false,
 		},
 		{
 			name: "Partial",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " wor"},
 			},
			stop: "or",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " w"},
 			},
@@ -41,13 +43,13 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Suffix",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " there"},
 				{Content: "!"},
 			},
 			stop: "!",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " there"},
 			},
@@ -55,13 +57,13 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Suffix partial",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " the"},
 				{Content: "re!"},
 			},
 			stop: "there!",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " "},
 			},
@@ -69,12 +71,12 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Middle",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " wo"},
 			},
 			stop: "llo w",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "He"},
 			},
 			expectedTrunc: true,
@@ -92,7 +94,7 @@ func TestTruncateStop(t *testing.T) {
 	}
 }
 
-func formatContentDiff(result, expected []CompletionResponse) string {
+func formatContentDiff(result, expected []llm.CompletionResponse) string {
 	var s string
 	for i := 0; i < len(result) || i < len(expected); i++ {
 		if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {

View File

@@ -1,23 +0,0 @@
-package common
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
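
The runner-local response and timings types above are deleted in favor of the shared type in the llm package. Only the Content field is exercised in this diff; as a rough sketch, paraphrased from the llm package (which this commit does not touch, so treat field details as an approximation rather than a verbatim copy), the replacement looks like:

// Approximate shape of llm.CompletionResponse around the time of this
// change; field names and tags are recalled from the llm package, not
// shown in this commit, so verify against the source before relying on them.
type CompletionResponse struct {
	Content            string        `json:"content"`
	DoneReason         string        `json:"done_reason"`
	Done               bool          `json:"done"`
	PromptEvalCount    int           `json:"prompt_eval_count"`
	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
	EvalCount          int           `json:"eval_count"`
	EvalDuration       time.Duration `json:"eval_duration"`
}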

View File

@@ -51,7 +51,7 @@ type Sequence struct {
 	pendingInputs []input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []common.CompletionResponse
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
@@ -61,7 +61,7 @@ type Sequence struct {
 	crossAttention bool
 
 	// channel to send responses over
-	responses chan common.CompletionResponse
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]common.CompletionResponse, 0),
-		responses:           make(chan common.CompletionResponse, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
@@ -277,7 +277,7 @@ func (s *Server) allNil() bool {
 
 func flushPending(seq *Sequence) bool {
 	pending := seq.pendingResponses
-	seq.pendingResponses = []common.CompletionResponse{}
+	seq.pendingResponses = []llm.CompletionResponse{}
 
 	for i, r := range pending {
 		if i == len(pending)-1 {
@@ -496,7 +496,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		seq.inputs = []input{{token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
 
 		sequence := ""
 		for _, r := range seq.pendingResponses {
 			sequence += r.Content
@@ -639,9 +639,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
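
With the channel element type now llm.CompletionResponse, the completion handler can encode each received value as-is instead of wrapping it in a fresh struct. A standalone sketch of that streaming shape, assuming a hypothetical helper rather than the actual handler:

package runner // illustrative package name, not from this commit

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/llm"
)

// streamResponses is a hypothetical reduction of the handler above: each
// llm.CompletionResponse received from the sequence is written to the wire
// unchanged, one JSON object per chunk.
func streamResponses(w http.ResponseWriter, responses <-chan llm.CompletionResponse) {
	enc := json.NewEncoder(w)
	for content := range responses {
		if err := enc.Encode(&content); err != nil {
			http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
			return
		}
		if f, ok := w.(http.Flusher); ok {
			f.Flush() // flush so the client sees each chunk immediately
		}
	}
}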

View File

@@ -53,13 +53,13 @@ type Sequence struct {
 	pendingInputs []input.Input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []common.CompletionResponse
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
 	// channel to send responses over
-	responses chan common.CompletionResponse
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]common.CompletionResponse, 0),
-		responses:           make(chan common.CompletionResponse, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		sampler:             params.sampler,
@@ -289,7 +289,7 @@ func (s *Server) allNil() bool {
 
 func flushPending(seq *Sequence) bool {
 	pending := seq.pendingResponses
-	seq.pendingResponses = []common.CompletionResponse{}
+	seq.pendingResponses = []llm.CompletionResponse{}
 
 	for i, r := range pending {
 		if i == len(pending)-1 {
@@ -483,7 +483,7 @@ func (s *Server) processBatch() error {
 		seq.inputs = []input.Input{{Token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
 
 		sequence := ""
 		for _, r := range seq.pendingResponses {
 			sequence += r.Content
@@ -625,9 +625,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return