prototype
This commit is contained in:
parent
64f95067ba
commit
fdbb0b5cfe
21
api/types.go
21
api/types.go
@ -77,6 +77,8 @@ type GenerateRequest struct {
|
|||||||
// request, for multimodal models.
|
// request, for multimodal models.
|
||||||
Images []ImageData `json:"images,omitempty"`
|
Images []ImageData `json:"images,omitempty"`
|
||||||
|
|
||||||
|
LogProbs int `json:"logprobs,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options. For example, temperature can be
|
// Options lists model-specific options. For example, temperature can be
|
||||||
// set through this field, if the model supports it.
|
// set through this field, if the model supports it.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
@ -103,6 +105,8 @@ type ChatRequest struct {
|
|||||||
// Tools is an optional list of tools the model has access to.
|
// Tools is an optional list of tools the model has access to.
|
||||||
Tools `json:"tools,omitempty"`
|
Tools `json:"tools,omitempty"`
|
||||||
|
|
||||||
|
LogProbs int `json:"logprobs,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
@ -182,13 +186,20 @@ func (t *ToolFunction) String() string {
|
|||||||
return string(bts)
|
return string(bts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TokenProbs struct {
|
||||||
|
TokenID int `json:"id"`
|
||||||
|
LogProb float32 `json:"logprob"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||||
// similar to [GenerateResponse].
|
// similar to [GenerateResponse].
|
||||||
type ChatResponse struct {
|
type ChatResponse struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
Message Message `json:"message"`
|
Message Message `json:"message"`
|
||||||
DoneReason string `json:"done_reason,omitempty"`
|
DoneReason string `json:"done_reason,omitempty"`
|
||||||
|
LogProbs []TokenProbs `json:"logprobs,omitempty"`
|
||||||
|
|
||||||
Done bool `json:"done"`
|
Done bool `json:"done"`
|
||||||
|
|
||||||
@ -452,6 +463,8 @@ type GenerateResponse struct {
|
|||||||
// can be sent in the next request to keep a conversational memory.
|
// can be sent in the next request to keep a conversational memory.
|
||||||
Context []int `json:"context,omitempty"`
|
Context []int `json:"context,omitempty"`
|
||||||
|
|
||||||
|
LogProbs []TokenProbs `json:"logprobs,omitempty"`
|
||||||
|
|
||||||
Metrics
|
Metrics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,18 +233,6 @@ func (c *Context) GetLogits() []float32 {
|
|||||||
return unsafe.Slice((*float32)(logits), vocabSize)
|
return unsafe.Slice((*float32)(logits), vocabSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Detokenize(tokens []int) (string, error) {
|
|
||||||
var text string
|
|
||||||
for _, token := range tokens {
|
|
||||||
piece := m.TokenToPiece(token)
|
|
||||||
if piece == "" {
|
|
||||||
return "", fmt.Errorf("failed to convert token %d to piece", token)
|
|
||||||
}
|
|
||||||
text += piece
|
|
||||||
}
|
|
||||||
return text, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ModelParams struct {
|
type ModelParams struct {
|
||||||
NumGpuLayers int
|
NumGpuLayers int
|
||||||
MainGpu int
|
MainGpu int
|
||||||
|
@ -104,6 +104,7 @@ type NewSequenceParams struct {
|
|||||||
numKeep int
|
numKeep int
|
||||||
samplingParams *llama.SamplingParams
|
samplingParams *llama.SamplingParams
|
||||||
embedding bool
|
embedding bool
|
||||||
|
logprobs int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||||
@ -164,6 +165,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|||||||
embeddingOnly: params.embedding,
|
embeddingOnly: params.embedding,
|
||||||
stop: params.stop,
|
stop: params.stop,
|
||||||
numKeep: params.numKeep,
|
numKeep: params.numKeep,
|
||||||
|
logprobs: params.logprobs,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -285,37 +287,34 @@ func flushPending(seq *Sequence) bool {
|
|||||||
if len(seq.pendingResponses) == 0 {
|
if len(seq.pendingResponses) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
content := ""
|
resps := []CompletionResponse{}
|
||||||
for _, resp := range seq.pendingResponses {
|
for _, resp := range seq.pendingResponses {
|
||||||
content += resp.Content
|
resps = append(resps, resp)
|
||||||
}
|
}
|
||||||
seq.pendingResponses = []CompletionResponse{}
|
seq.pendingResponses = []CompletionResponse{}
|
||||||
|
|
||||||
// Check if there are any partial UTF-8 characters remaining.
|
// TODO: figure out this result logic
|
||||||
// We already check and queue as we are generating but some may
|
result := false
|
||||||
// still make it here:
|
for _, resp := range resps {
|
||||||
// - Sequence is ending, e.g. generation limit has been hit
|
// Check if there are any partial UTF-8 characters remaining.
|
||||||
// - Invalid characters in the middle of a string
|
// We already check and queue as we are generating but some may
|
||||||
// This is a stricter check to ensure we never output invalid Unicode.
|
// still make it here:
|
||||||
for !utf8.ValidString(content) {
|
// - Sequence is ending, e.g. generation limit has been hit
|
||||||
content = content[:len(content)-1]
|
// - Invalid characters in the middle of a string
|
||||||
|
// This is a stricter check to ensure we never output invalid Unicode.
|
||||||
|
for !utf8.ValidString(resp.Content) {
|
||||||
|
resp.Content = resp.Content[:len(resp.Content)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case seq.responses <- resp:
|
||||||
|
result = true
|
||||||
|
case <-seq.quit:
|
||||||
|
result = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add logits if requested and available
|
return result
|
||||||
wantLogits := true
|
|
||||||
if wantLogits && seq.logits != nil {
|
|
||||||
// resp.Logits = seq.logits
|
|
||||||
seq.logits = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case seq.responses <- CompletionResponse{
|
|
||||||
Content: content,
|
|
||||||
}:
|
|
||||||
return true
|
|
||||||
case <-seq.quit:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||||
@ -371,10 +370,11 @@ func (s *Server) run(ctx context.Context) {
|
|||||||
|
|
||||||
// TokenProbs represents probability information for a token
|
// TokenProbs represents probability information for a token
|
||||||
type TokenProbs struct {
|
type TokenProbs struct {
|
||||||
TokenID int
|
TokenID int `json:"id"`
|
||||||
Logit float32
|
Logit float32 `json:"logit"`
|
||||||
Prob float32
|
Prob float32 `json:"prob"`
|
||||||
LogProb float32
|
LogProb float32 `json:"logprob"`
|
||||||
|
Token string `json:"token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// probs returns sorted token probabilities for a specific token index
|
// probs returns sorted token probabilities for a specific token index
|
||||||
@ -553,9 +553,17 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
|
|
||||||
seq.numPredicted++
|
seq.numPredicted++
|
||||||
|
|
||||||
|
resp := CompletionResponse{Content: piece}
|
||||||
|
|
||||||
if seq.logprobs > 0 {
|
if seq.logprobs > 0 {
|
||||||
// TODO: return selected token in logprobs always
|
// TODO: return selected token in logprobs always
|
||||||
// probs := s.probs(seq)
|
resp.LogProbs = s.probs(seq)
|
||||||
|
// TODO: fix this logprobs limit
|
||||||
|
resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)]
|
||||||
|
for i := range resp.LogProbs {
|
||||||
|
// decode the token id to a piece
|
||||||
|
resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if it's an end of sequence token, break
|
// if it's an end of sequence token, break
|
||||||
@ -571,7 +579,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
seq.inputs = []input{{token: token}}
|
seq.inputs = []input{{token: token}}
|
||||||
|
|
||||||
// TODO: add probs here
|
// TODO: add probs here
|
||||||
seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece})
|
seq.pendingResponses = append(seq.pendingResponses, resp)
|
||||||
var sequence string
|
var sequence string
|
||||||
for _, r := range seq.pendingResponses {
|
for _, r := range seq.pendingResponses {
|
||||||
sequence += r.Content
|
sequence += r.Content
|
||||||
@ -580,10 +588,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
if ok, stop := findStop(sequence, seq.stop); ok {
|
if ok, stop := findStop(sequence, seq.stop); ok {
|
||||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||||
|
|
||||||
|
// TODO: fix this stop sequence caching
|
||||||
var tokenTruncated bool
|
var tokenTruncated bool
|
||||||
origLen := len(seq.pendingResponses)
|
origLen := len(sequence)
|
||||||
seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
|
sequence, tokenTruncated = truncateStop(sequence, stop)
|
||||||
newLen := len(seq.pendingResponses)
|
newLen := len(sequence)
|
||||||
|
|
||||||
// Update the cache based on the tokens that will be returned:
|
// Update the cache based on the tokens that will be returned:
|
||||||
// - We have 1 token more than is currently in the cache because
|
// - We have 1 token more than is currently in the cache because
|
||||||
@ -654,6 +663,7 @@ type CompletionRequest struct {
|
|||||||
Images []ImageData `json:"image_data"`
|
Images []ImageData `json:"image_data"`
|
||||||
Grammar string `json:"grammar"`
|
Grammar string `json:"grammar"`
|
||||||
CachePrompt bool `json:"cache_prompt"`
|
CachePrompt bool `json:"cache_prompt"`
|
||||||
|
Logprobs int `json:"logprobs,omitempty"`
|
||||||
|
|
||||||
Options
|
Options
|
||||||
}
|
}
|
||||||
@ -669,8 +679,10 @@ type CompletionResponse struct {
|
|||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Stop bool `json:"stop"`
|
Stop bool `json:"stop"`
|
||||||
|
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
LogProbs []TokenProbs `json:"logprobs,omitempty"`
|
||||||
|
|
||||||
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
StoppedLimit bool `json:"stopped_limit,omitempty"`
|
||||||
PredictedN int `json:"predicted_n,omitempty"`
|
PredictedN int `json:"predicted_n,omitempty"`
|
||||||
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
PredictedMS float64 `json:"predicted_ms,omitempty"`
|
||||||
@ -688,10 +700,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the headers to indicate streaming
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Transfer-Encoding", "chunked")
|
|
||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||||
@ -720,6 +728,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
numKeep: req.NumKeep,
|
numKeep: req.NumKeep,
|
||||||
samplingParams: &samplingParams,
|
samplingParams: &samplingParams,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
|
logprobs: req.Logprobs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||||
@ -769,6 +778,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
case resp, ok := <-seq.responses:
|
case resp, ok := <-seq.responses:
|
||||||
if ok {
|
if ok {
|
||||||
|
fmt.Println("response", resp)
|
||||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
close(seq.quit)
|
close(seq.quit)
|
||||||
|
@ -26,46 +26,15 @@ func containsStopSuffix(sequence string, stops []string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncateStop removes the provided stop string from pieces,
|
// truncateStop removes the provided stop string from sequence,
|
||||||
// returning the partial pieces with stop removed, including truncating
|
// returning both the truncated sequence and a bool indicating if truncation occurred
|
||||||
// the last piece if required (and signalling if this was the case)
|
func truncateStop(sequence string, stop string) (string, bool) {
|
||||||
func truncateStop(pieces []CompletionResponse, stop string) ([]CompletionResponse, bool) {
|
index := strings.Index(sequence, stop)
|
||||||
// Build complete string and find stop position
|
if index == -1 {
|
||||||
var completeStr string
|
return sequence, false
|
||||||
for _, piece := range pieces {
|
|
||||||
completeStr += piece.Content
|
|
||||||
}
|
}
|
||||||
|
|
||||||
stopStart := strings.Index(completeStr, stop)
|
return sequence[:index], true
|
||||||
if stopStart == -1 {
|
|
||||||
return pieces, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build result up to stop position
|
|
||||||
result := make([]CompletionResponse, 0)
|
|
||||||
accumulated := 0
|
|
||||||
|
|
||||||
truncated := false
|
|
||||||
for _, piece := range pieces {
|
|
||||||
if accumulated+len(piece.Content) <= stopStart {
|
|
||||||
result = append(result, piece)
|
|
||||||
accumulated += len(piece.Content)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if accumulated < stopStart {
|
|
||||||
truncPiece := piece
|
|
||||||
truncPiece.Content = piece.Content[:stopStart-accumulated]
|
|
||||||
if len(truncPiece.Content) > 0 {
|
|
||||||
result = append(result, truncPiece)
|
|
||||||
truncated = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signal if we had to truncate the last piece
|
|
||||||
return result, truncated
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func incompleteUnicode(token string) bool {
|
func incompleteUnicode(token string) bool {
|
||||||
|
@ -1,90 +1,60 @@
|
|||||||
package runner
|
package runner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTruncateStop(t *testing.T) {
|
func TestTruncateStop(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
pieces []CompletionResponse
|
sequence string
|
||||||
stop string
|
stop string
|
||||||
expected []CompletionResponse
|
expected string
|
||||||
expectedTrunc bool
|
expectedTrunc bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Single word",
|
name: "Single word",
|
||||||
pieces: []CompletionResponse{
|
sequence: "helloworld",
|
||||||
{Content: "hello"},
|
stop: "world",
|
||||||
{Content: "world"},
|
expected: "hello",
|
||||||
},
|
expectedTrunc: true,
|
||||||
stop: "world",
|
},
|
||||||
expected: []CompletionResponse{
|
{
|
||||||
{Content: "hello"},
|
name: "Partial",
|
||||||
},
|
sequence: "hellowor",
|
||||||
|
stop: "or",
|
||||||
|
expected: "hellow",
|
||||||
|
expectedTrunc: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Suffix",
|
||||||
|
sequence: "Hello there!",
|
||||||
|
stop: "!",
|
||||||
|
expected: "Hello there",
|
||||||
|
expectedTrunc: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Middle",
|
||||||
|
sequence: "hello wor",
|
||||||
|
stop: "llo w",
|
||||||
|
expected: "he",
|
||||||
|
expectedTrunc: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No stop found",
|
||||||
|
sequence: "hello world",
|
||||||
|
stop: "xyz",
|
||||||
|
expected: "hello world",
|
||||||
expectedTrunc: false,
|
expectedTrunc: false,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "Partial",
|
|
||||||
pieces: []CompletionResponse{
|
|
||||||
{Content: "hello"},
|
|
||||||
{Content: "wor"},
|
|
||||||
},
|
|
||||||
stop: "or",
|
|
||||||
expected: []CompletionResponse{
|
|
||||||
{Content: "hello"},
|
|
||||||
{Content: "w"},
|
|
||||||
},
|
|
||||||
expectedTrunc: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Suffix",
|
|
||||||
pieces: []CompletionResponse{
|
|
||||||
{Content: "Hello"},
|
|
||||||
{Content: " there"},
|
|
||||||
{Content: "!"},
|
|
||||||
},
|
|
||||||
stop: "!",
|
|
||||||
expected: []CompletionResponse{
|
|
||||||
{Content: "Hello"},
|
|
||||||
{Content: " there"},
|
|
||||||
},
|
|
||||||
expectedTrunc: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Suffix partial",
|
|
||||||
pieces: []CompletionResponse{
|
|
||||||
{Content: "Hello"},
|
|
||||||
{Content: " the"},
|
|
||||||
{Content: "re!"},
|
|
||||||
},
|
|
||||||
stop: "there!",
|
|
||||||
expected: []CompletionResponse{
|
|
||||||
{Content: "Hello"},
|
|
||||||
{Content: " "},
|
|
||||||
},
|
|
||||||
expectedTrunc: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Middle",
|
|
||||||
pieces: []CompletionResponse{
|
|
||||||
{Content: "hello"},
|
|
||||||
{Content: " wor"},
|
|
||||||
},
|
|
||||||
stop: "llo w",
|
|
||||||
expected: []CompletionResponse{
|
|
||||||
{Content: "he"},
|
|
||||||
},
|
|
||||||
expectedTrunc: true,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result, resultTrunc := truncateStop(tt.pieces, tt.stop)
|
result, truncated := truncateStop(tt.sequence, tt.stop)
|
||||||
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
|
if result != tt.expected || truncated != tt.expectedTrunc {
|
||||||
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
|
t.Errorf("truncateStop(%q, %q): have %q (%v); want %q (%v)",
|
||||||
|
tt.sequence, tt.stop, result, truncated, tt.expected, tt.expectedTrunc)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -644,12 +644,22 @@ type ImageData struct {
|
|||||||
AspectRatioID int `json:"aspect_ratio_id"`
|
AspectRatioID int `json:"aspect_ratio_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TokenProbs represents probability information for a token
|
||||||
|
type TokenProbs struct {
|
||||||
|
TokenID int `json:"id"`
|
||||||
|
Logit float32 `json:"logit"`
|
||||||
|
Prob float32 `json:"prob"`
|
||||||
|
LogProb float32 `json:"logprob"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
type completion struct {
|
type completion struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Stop bool `json:"stop"`
|
Stop bool `json:"stop"`
|
||||||
StoppedLimit bool `json:"stopped_limit"`
|
StoppedLimit bool `json:"stopped_limit"`
|
||||||
|
LogProbs []TokenProbs `json:"logprobs"`
|
||||||
|
|
||||||
Timings struct {
|
Timings struct {
|
||||||
PredictedN int `json:"predicted_n"`
|
PredictedN int `json:"predicted_n"`
|
||||||
@ -660,14 +670,16 @@ type completion struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CompletionRequest struct {
|
type CompletionRequest struct {
|
||||||
Prompt string
|
Prompt string
|
||||||
Format json.RawMessage
|
Format json.RawMessage
|
||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options *api.Options
|
LogProbs int
|
||||||
|
Options *api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
Content string
|
Content string
|
||||||
|
LogProbs []TokenProbs
|
||||||
DoneReason string
|
DoneReason string
|
||||||
Done bool
|
Done bool
|
||||||
PromptEvalCount int
|
PromptEvalCount int
|
||||||
@ -698,9 +710,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
"seed": req.Options.Seed,
|
"seed": req.Options.Seed,
|
||||||
"stop": req.Options.Stop,
|
"stop": req.Options.Stop,
|
||||||
"image_data": req.Images,
|
"image_data": req.Images,
|
||||||
|
"logprobs": req.LogProbs,
|
||||||
"cache_prompt": true,
|
"cache_prompt": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("completion request:", request)
|
||||||
|
|
||||||
if len(req.Format) > 0 {
|
if len(req.Format) > 0 {
|
||||||
switch string(req.Format) {
|
switch string(req.Format) {
|
||||||
case `null`, `""`:
|
case `null`, `""`:
|
||||||
@ -796,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// slog.Debug("got line", "line", string(line))
|
|
||||||
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||||
if !ok {
|
if !ok {
|
||||||
evt = line
|
evt = line
|
||||||
@ -822,7 +836,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
|
|
||||||
if c.Content != "" {
|
if c.Content != "" {
|
||||||
fn(CompletionResponse{
|
fn(CompletionResponse{
|
||||||
Content: c.Content,
|
Content: c.Content,
|
||||||
|
LogProbs: c.LogProbs,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -839,6 +854,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||||
EvalCount: c.Timings.PredictedN,
|
EvalCount: c.Timings.PredictedN,
|
||||||
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
||||||
|
LogProbs: c.LogProbs,
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -293,11 +293,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
LogProbs: req.LogProbs,
|
||||||
|
Options: opts,
|
||||||
}, func(cr llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
|
fmt.Printf("banana: %#v\n", cr)
|
||||||
res := api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
@ -311,6 +313,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
EvalDuration: cr.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
for _, p := range cr.LogProbs {
|
||||||
|
res.LogProbs = append(res.LogProbs, api.TokenProbs{
|
||||||
|
TokenID: p.TokenID,
|
||||||
|
LogProb: p.LogProb,
|
||||||
|
Token: p.Token,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := sb.WriteString(cr.Content); err != nil {
|
if _, err := sb.WriteString(cr.Content); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
@ -1466,10 +1475,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
var toolCallIndex int = 0
|
var toolCallIndex int = 0
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
LogProbs: req.LogProbs,
|
||||||
|
Options: opts,
|
||||||
}, func(r llm.CompletionResponse) {
|
}, func(r llm.CompletionResponse) {
|
||||||
res := api.ChatResponse{
|
res := api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
@ -1484,6 +1494,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: r.EvalDuration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
for _, p := range r.LogProbs {
|
||||||
|
res.LogProbs = append(res.LogProbs, api.TokenProbs{
|
||||||
|
TokenID: p.TokenID,
|
||||||
|
LogProb: p.LogProb,
|
||||||
|
Token: p.Token,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if r.Done {
|
if r.Done {
|
||||||
res.TotalDuration = time.Since(checkpointStart)
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user