update completion responses
commit 946fdd5388
parent 905da35468
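
Every hunk below makes the same mechanical substitution: the runner-local common.CompletionResponse type is deleted and its users switch to the shared llm.CompletionResponse from github.com/ollama/ollama/llm. In miniature (names taken from the hunks; surrounding code elided):

	// before: runner-private response type
	pendingResponses []common.CompletionResponse
	responses        chan common.CompletionResponse

	// after: shared type from package llm
	pendingResponses []llm.CompletionResponse
	responses        chan llm.CompletionResponse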
@@ -2,6 +2,8 @@ package common
 
 import (
 	"strings"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func FindStop(sequence string, stops []string) (bool, string) {
@@ -29,7 +31,7 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
-func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
+func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
 	var sequence string
 	for _, resp := range resps {
 		sequence += resp.Content
@@ -45,7 +47,7 @@ func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse
 		return nil, true
 	}
 
-	result := make([]CompletionResponse, 0, len(resps))
+	result := make([]llm.CompletionResponse, 0, len(resps))
 
 	// Track position in truncated sequence
 	pos := 0
@@ -60,7 +62,7 @@ func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse
 			truncationHappened = true
 		}
 		if len(chunk) > 0 {
-			result = append(result, CompletionResponse{Content: chunk})
+			result = append(result, llm.CompletionResponse{Content: chunk})
 		}
 		pos += len(resp.Content)
 	}
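
The four hunks above change only type names; TruncateStop's control flow is untouched. For orientation, a sketch of the full function consistent with the fragments visible here — the stop-index computation falls between hunks and is assumed, not copied from the source:

	func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
		var sequence string
		for _, resp := range resps {
			sequence += resp.Content
		}

		// Assumed: find where the stop string begins; everything from there on is dropped.
		index := strings.Index(sequence, stop)
		if index < 0 {
			return resps, false
		}
		truncated := sequence[:index]
		if len(truncated) == 0 {
			return nil, true
		}

		result := make([]llm.CompletionResponse, 0, len(resps))
		truncationHappened := false

		// Track position in truncated sequence
		pos := 0
		for _, resp := range resps {
			if pos >= len(truncated) {
				break
			}
			chunk := resp.Content
			if pos+len(chunk) > len(truncated) {
				// this piece straddles the stop boundary: cut it and signal truncation
				chunk = chunk[:len(truncated)-pos]
				truncationHappened = true
			}
			if len(chunk) > 0 {
				result = append(result, llm.CompletionResponse{Content: chunk})
			}
			pos += len(resp.Content)
		}
		return result, truncationHappened
	}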
@@ -4,36 +4,38 @@ import (
 	"fmt"
 	"reflect"
 	"testing"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name          string
-		pieces        []CompletionResponse
+		pieces        []llm.CompletionResponse
 		stop          string
-		expected      []CompletionResponse
+		expected      []llm.CompletionResponse
 		expectedTrunc bool
 	}{
 		{
 			name: "Single word",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: "world"},
 			},
 			stop: "world",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 			},
 			expectedTrunc: false,
 		},
 		{
 			name: "Partial",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " wor"},
 			},
 			stop: "or",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " w"},
 			},
@@ -41,13 +43,13 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Suffix",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " there"},
 				{Content: "!"},
 			},
 			stop: "!",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " there"},
 			},
@@ -55,13 +57,13 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Suffix partial",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " the"},
 				{Content: "re!"},
 			},
 			stop: "there!",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " "},
 			},
@@ -69,12 +71,12 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Middle",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " wo"},
 			},
 			stop: "llo w",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "He"},
 			},
 			expectedTrunc: true,
@@ -92,7 +94,7 @@ func TestTruncateStop(t *testing.T) {
 		}
 	}
 }
 
-func formatContentDiff(result, expected []CompletionResponse) string {
+func formatContentDiff(result, expected []llm.CompletionResponse) string {
 	var s string
 	for i := 0; i < len(result) || i < len(expected); i++ {
 		if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
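
A quick worked example of the behavior the "Middle" case pins down — a stop string that starts inside a piece truncates that piece mid-chunk and reports truncation:

	resps := []llm.CompletionResponse{{Content: "Hello"}, {Content: " wo"}}
	out, truncated := TruncateStop(resps, "llo w")
	// out == []llm.CompletionResponse{{Content: "He"}}; truncated == true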
@@ -1,23 +0,0 @@
-package common
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
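
With this file gone, the runner-local Timings type disappears too. The replacement lives in github.com/ollama/ollama/llm and is not shown in this diff; the hunks only rely on it having at least a Content field:

	// Assumed minimal shape; the real llm.CompletionResponse defines more fields
	// (the deleted json tags above suggest the kind of metadata it carries).
	type CompletionResponse struct {
		Content string `json:"content"`
		// ...
	}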
@@ -51,7 +51,7 @@ type Sequence struct {
 	pendingInputs []input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []common.CompletionResponse
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
@@ -61,7 +61,7 @@ type Sequence struct {
 	crossAttention bool
 
 	// channel to send responses over
-	responses chan common.CompletionResponse
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
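
Read together, the two Sequence hunks leave the struct looking like this (only the fields visible in this diff; the real struct has more):

	type Sequence struct {
		// ...
		pendingInputs []input

		// tokens that have been generated but not returned yet (e.g. for stop sequences)
		pendingResponses []llm.CompletionResponse

		// input cache being used by this sequence
		cache *InputCacheSlot

		crossAttention bool

		// channel to send responses over
		responses chan llm.CompletionResponse

		// channel to stop decoding (such as if the remote connection is closed)
		quit chan bool
		// ...
	}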
@@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]common.CompletionResponse, 0),
-		responses:           make(chan common.CompletionResponse, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
@@ -277,7 +277,7 @@ func (s *Server) allNil() bool {
 
 func flushPending(seq *Sequence) bool {
 	pending := seq.pendingResponses
-	seq.pendingResponses = []common.CompletionResponse{}
+	seq.pendingResponses = []llm.CompletionResponse{}
 
 	for i, r := range pending {
 		if i == len(pending)-1 {
@@ -496,7 +496,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.inputs = []input{{token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
 		sequence := ""
 		for _, r := range seq.pendingResponses {
 			sequence += r.Content
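
This is the buffering side of stop handling: each decoded piece is appended to pendingResponses, and the concatenation is re-scanned for stop strings with the common-package helpers whose signatures appear above. Schematically, with the stop list and follow-up handling assumed rather than taken from the source:

	// stops []string is assumed; the real code reads it from the sequence's params.
	seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})

	sequence := ""
	for _, r := range seq.pendingResponses {
		sequence += r.Content
	}

	if ok, stop := common.FindStop(sequence, stops); ok {
		// drop the stop string and anything after it, then finish the sequence
		seq.pendingResponses, _ = common.TruncateStop(seq.pendingResponses, stop)
	}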
@@ -639,9 +639,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
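
Because seq.responses now carries llm.CompletionResponse values rather than bare content, the completion handler no longer rebuilds a response around each chunk: the received value is encoded as-is, and the two deleted lines were the old wrapping. The hunks below apply the identical change to the second runner.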
@@ -53,13 +53,13 @@ type Sequence struct {
 	pendingInputs []input.Input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []common.CompletionResponse
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
 	// channel to send responses over
-	responses chan common.CompletionResponse
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]common.CompletionResponse, 0),
-		responses:           make(chan common.CompletionResponse, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		sampler:             params.sampler,
@@ -289,7 +289,7 @@ func (s *Server) allNil() bool {
 
 func flushPending(seq *Sequence) bool {
 	pending := seq.pendingResponses
-	seq.pendingResponses = []common.CompletionResponse{}
+	seq.pendingResponses = []llm.CompletionResponse{}
 
 	for i, r := range pending {
 		if i == len(pending)-1 {
@@ -483,7 +483,7 @@ func (s *Server) processBatch() error {
 
 		seq.inputs = []input.Input{{Token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
 		sequence := ""
 		for _, r := range seq.pendingResponses {
 			sequence += r.Content
@@ -625,9 +625,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return