diff --git a/runner/common/stop.go b/runner/common/stop.go
index 748f3f0b7..79cdd9e68 100644
--- a/runner/common/stop.go
+++ b/runner/common/stop.go
@@ -2,6 +2,8 @@ package common
 
 import (
 	"strings"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func FindStop(sequence string, stops []string) (bool, string) {
@@ -29,7 +31,7 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
-func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
+func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
 	var sequence string
 	for _, resp := range resps {
 		sequence += resp.Content
@@ -45,7 +47,7 @@ func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse
 		return nil, true
 	}
 
-	result := make([]CompletionResponse, 0, len(resps))
+	result := make([]llm.CompletionResponse, 0, len(resps))
 
 	// Track position in truncated sequence
 	pos := 0
@@ -60,7 +62,7 @@ func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse
 			truncationHappened = true
 		}
 		if len(chunk) > 0 {
-			result = append(result, CompletionResponse{Content: chunk})
+			result = append(result, llm.CompletionResponse{Content: chunk})
 		}
 		pos += len(resp.Content)
 	}
diff --git a/runner/common/stop_test.go b/runner/common/stop_test.go
index df68172fa..dc0e7ac72 100644
--- a/runner/common/stop_test.go
+++ b/runner/common/stop_test.go
@@ -4,36 +4,38 @@ import (
 	"fmt"
 	"reflect"
 	"testing"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name          string
-		pieces        []CompletionResponse
+		pieces        []llm.CompletionResponse
 		stop          string
-		expected      []CompletionResponse
+		expected      []llm.CompletionResponse
 		expectedTrunc bool
 	}{
 		{
 			name: "Single word",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: "world"},
 			},
 			stop: "world",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 			},
 			expectedTrunc: false,
 		},
 		{
 			name: "Partial",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " wor"},
 			},
 			stop: "or",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " w"},
 			},
@@ -41,13 +43,13 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Suffix",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " there"},
 				{Content: "!"},
 			},
 			stop: "!",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " there"},
 			},
@@ -55,13 +57,13 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Suffix partial",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " the"},
 				{Content: "re!"},
 			},
 			stop: "there!",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " "},
 			},
@@ -69,12 +71,12 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Middle",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " wo"},
 			},
 			stop: "llo w",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "He"},
 			},
 			expectedTrunc: true,
@@ -92,7 +94,7 @@ func TestTruncateStop(t *testing.T) {
 	}
 }
 
-func formatContentDiff(result, expected []CompletionResponse) string {
+func formatContentDiff(result, expected []llm.CompletionResponse) string {
 	var s string
 	for i := 0; i < len(result) || i < len(expected); i++ {
 		if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
diff --git a/runner/common/types.go b/runner/common/types.go
deleted file mode 100644
index 51487540a..000000000
--- a/runner/common/types.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package common
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index accbfddb3..6dd67bfd5 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -51,7 +51,7 @@ type Sequence struct {
 	pendingInputs []input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []common.CompletionResponse
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
@@ -61,7 +61,7 @@ type Sequence struct {
 	crossAttention bool
 
 	// channel to send responses over
-	responses chan common.CompletionResponse
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]common.CompletionResponse, 0),
-		responses:           make(chan common.CompletionResponse, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
@@ -277,7 +277,7 @@ func (s *Server) allNil() bool {
 
 func flushPending(seq *Sequence) bool {
 	pending := seq.pendingResponses
-	seq.pendingResponses = []common.CompletionResponse{}
+	seq.pendingResponses = []llm.CompletionResponse{}
 
 	for i, r := range pending {
 		if i == len(pending)-1 {
@@ -496,7 +496,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 			seq.inputs = []input{{token: token}}
 
-			seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+			seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
 			sequence := ""
 			for _, r := range seq.pendingResponses {
 				sequence += r.Content
@@ -639,9 +639,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 11385a8a0..3b7c5c570 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -53,13 +53,13 @@ type Sequence struct {
 	pendingInputs []input.Input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []common.CompletionResponse
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
 	// channel to send responses over
-	responses chan common.CompletionResponse
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]common.CompletionResponse, 0),
-		responses:           make(chan common.CompletionResponse, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		sampler:             params.sampler,
@@ -289,7 +289,7 @@ func (s *Server) allNil() bool {
 
 func flushPending(seq *Sequence) bool {
 	pending := seq.pendingResponses
-	seq.pendingResponses = []common.CompletionResponse{}
+	seq.pendingResponses = []llm.CompletionResponse{}
 
 	for i, r := range pending {
 		if i == len(pending)-1 {
@@ -483,7 +483,7 @@ func (s *Server) processBatch() error {
 
 			seq.inputs = []input.Input{{Token: token}}
 
-			seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+			seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
 			sequence := ""
 			for _, r := range seq.pendingResponses {
 				sequence += r.Content
@@ -625,9 +625,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
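
Reviewer note: with runner/common/types.go deleted, both runners stream llm.CompletionResponse directly and common.TruncateStop now operates on that type. A minimal sketch of the new call pattern, mirroring the "Suffix partial" test case above; it assumes llm.CompletionResponse keeps a Content string field (as implied by the composite literals in this patch) and that runner/common lives under the github.com/ollama/ollama module path used by the new imports:

	package main

	import (
		"fmt"

		"github.com/ollama/ollama/llm"
		"github.com/ollama/ollama/runner/common"
	)

	func main() {
		// Streamed pieces with the stop string "there!" split across
		// chunk boundaries.
		pieces := []llm.CompletionResponse{
			{Content: "Hello"},
			{Content: " the"},
			{Content: "re!"},
		}

		// TruncateStop strips the stop string, truncating the trailing
		// pieces as needed and reporting whether truncation occurred.
		trimmed, truncated := common.TruncateStop(pieces, "there!")
		for _, r := range trimmed {
			fmt.Printf("%q ", r.Content) // "Hello" " "
		}
		fmt.Println(truncated) // true
	}

Because completion handlers now encode each llm.CompletionResponse from the channel as-is, per-chunk metadata survives to the HTTP response instead of being rewrapped into a bare Content-only struct.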