Compare commits
5 Commits: main ... brucemacd/

Author | SHA1 | Date
---|---|---
| b88489a87e |
| fdbb0b5cfe |
| 64f95067ba |
| 6dfcdec2da |
| 7d16ec8fe8 |
api/types.go (21 changes)

@@ -77,6 +77,8 @@ type GenerateRequest struct {
 	// request, for multimodal models.
 	Images []ImageData `json:"images,omitempty"`

+	LogProbs int `json:"logprobs,omitempty"`
+
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
@@ -103,6 +105,8 @@ type ChatRequest struct {
 	// Tools is an optional list of tools the model has access to.
 	Tools `json:"tools,omitempty"`

+	LogProbs int `json:"logprobs,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
@@ -182,13 +186,20 @@ func (t *ToolFunction) String() string {
 	return string(bts)
 }

+type TokenProbs struct {
+	TokenID int     `json:"id"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-	Model      string    `json:"model"`
-	CreatedAt  time.Time `json:"created_at"`
-	Message    Message   `json:"message"`
-	DoneReason string    `json:"done_reason,omitempty"`
+	Model      string       `json:"model"`
+	CreatedAt  time.Time    `json:"created_at"`
+	Message    Message      `json:"message"`
+	DoneReason string       `json:"done_reason,omitempty"`
+	LogProbs   []TokenProbs `json:"logprobs,omitempty"`

 	Done bool `json:"done"`

@@ -452,6 +463,8 @@ type GenerateResponse struct {
 	// can be sent in the next request to keep a conversational memory.
 	Context []int `json:"context,omitempty"`

+	LogProbs []TokenProbs `json:"logprobs,omitempty"`
+
 	Metrics
 }
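For orientation, here is a minimal sketch of how a client could request log probabilities through these new fields. It assumes only what the diff above adds (the LogProbs request field and the TokenProbs/LogProbs response fields); the model name and prompt are placeholders:

    package main

    import (
    	"context"
    	"fmt"

    	"github.com/ollama/ollama/api"
    )

    func main() {
    	client, err := api.ClientFromEnvironment()
    	if err != nil {
    		panic(err)
    	}

    	req := &api.GenerateRequest{
    		Model:    "llama3.2", // placeholder model name
    		Prompt:   "Why is the sky blue?",
    		LogProbs: 5, // ask for the top 5 token log probabilities per step
    	}

    	// Print the token alternatives streamed back with each response chunk.
    	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
    		for _, p := range resp.LogProbs {
    			fmt.Printf("token=%q logprob=%f\n", p.Token, p.LogProb)
    		}
    		return nil
    	})
    	if err != nil {
    		panic(err)
    	}
    }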
llama/llama.go

@@ -50,7 +50,7 @@ import (
 	_ "github.com/ollama/ollama/llama/llama.cpp/common"
 	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
 	_ "github.com/ollama/ollama/llama/llama.cpp/src"
-	"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 )

 func BackendInit() {
@@ -220,6 +220,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return embeddings
 }

+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine array size
+	vocabSize := c.Model().NumVocab()
+	return unsafe.Slice((*float32)(logits), vocabSize)
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
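GetLogits wraps llama.cpp's llama_get_logits, which points at the logit buffer for the last decoded batch position. A short illustrative sketch of using it, assuming c is a *Context that has just run Decode (greedy argmax is my example here, not something the diff itself does):

    // Pick the most likely next token from the raw logits (greedy decoding).
    logits := c.GetLogits()
    best := 0
    for i, l := range logits {
    	if l > logits[best] {
    		best = i
    	}
    }
    piece := c.Model().TokenToPiece(best)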
llama/runner/runner.go

@@ -8,12 +8,14 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
 	"path/filepath"
 	"regexp"
 	"runtime"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -48,8 +50,9 @@ type Sequence struct {
 	// inputs that have been added to a batch but not yet submitted to Decode
 	pendingInputs []input

+	// TODO: update this comment
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []CompletionResponse

 	// input cache being used by this sequence
 	cache *InputCacheSlot
@@ -59,7 +62,7 @@ type Sequence struct {
 	crossAttention bool

 	// channel to send responses over
-	responses chan string
+	responses chan CompletionResponse

 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -83,6 +86,11 @@ type Sequence struct {

 	doneReason string

+	logits []float32
+
+	// number of logprobs to return with the completion response
+	logprobs int
+
 	// Metrics
 	startProcessingTime time.Time
 	startGenerationTime time.Time
@@ -96,6 +104,7 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	logprobs       int
 }

 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -148,14 +157,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		pendingResponses:    make([]CompletionResponse, 0),
+		responses:           make(chan CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
 		embeddingOnly:       params.embedding,
 		stop:                params.stop,
 		numKeep:             params.numKeep,
+		logprobs:            params.logprobs,
 	}, nil
 }
@@ -274,29 +284,37 @@ func (s *Server) allNil() bool {
 }

 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
-
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
-	}
-
-	if len(joined) == 0 {
+	if len(seq.pendingResponses) == 0 {
 		return true
 	}

-	select {
-	case seq.responses <- joined:
-		return true
-	case <-seq.quit:
-		return false
+	resps := []CompletionResponse{}
+	for _, resp := range seq.pendingResponses {
+		resps = append(resps, resp)
+	}
+	seq.pendingResponses = []CompletionResponse{}
+
+	// TODO: figure out this result logic
+	result := false
+	for _, resp := range resps {
+		// Check if there are any partial UTF-8 characters remaining.
+		// We already check and queue as we are generating but some may
+		// still make it here:
+		// - Sequence is ending, e.g. generation limit has been hit
+		// - Invalid characters in the middle of a string
+		// This is a stricter check to ensure we never output invalid Unicode.
+		for !utf8.ValidString(resp.Content) {
+			resp.Content = resp.Content[:len(resp.Content)-1]
+		}
+
+		select {
+		case seq.responses <- resp:
+			result = true
+		case <-seq.quit:
+			result = false
+		}
 	}
+
+	return result
 }

 func (s *Server) removeSequence(seqIndex int, reason string) {
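Two behavioral details of the rewritten flushPending stand out. The UTF-8 trimming now runs per buffered CompletionResponse rather than once over the joined string, so a multi-byte character that straddles two buffered responses could lose trailing bytes that the following chunk would have completed; the old joined-string check could not hit that case. And, as the TODO admits, result reflects only the final loop iteration: sends that already succeeded before seq.quit fires are not reported.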
@@ -350,6 +368,63 @@ func (s *Server) run(ctx context.Context) {
 	}
 }

+// TokenProbs represents probability information for a token
+type TokenProbs struct {
+	TokenID int     `json:"id"`
+	Logit   float32 `json:"logit"`
+	Prob    float32 `json:"prob"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
+}
+
+// probs returns sorted token probabilities for a specific token index
+func probs(logits []float32, vocabSize int) []TokenProbs {
+	probs := make([]TokenProbs, vocabSize)
+
+	// Initialize token data with logits
+	for i := 0; i < vocabSize; i++ {
+		probs[i] = TokenProbs{
+			TokenID: i,
+			Logit:   logits[i],
+		}
+	}
+
+	// Sort tokens by logits in descending order
+	sort.Slice(probs, func(i, j int) bool {
+		return probs[i].Logit > probs[j].Logit
+	})
+
+	// Apply softmax
+	maxLogit := probs[0].Logit
+	var sum float32 = 0.0
+
+	for i := range probs {
+		p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
+		probs[i].Prob = p
+		sum += p
+	}
+
+	// Normalize probabilities and calculate log probs
+	for i := range probs {
+		prob := probs[i].Prob / sum
+		probs[i].Prob = prob
+		probs[i].LogProb = float32(math.Log(float64(prob)))
+	}
+
+	return probs
+}
+
+// probs returns sorted token probabilities for a specific token index
+func (s *Server) probs(seq *Sequence) []TokenProbs {
+	// Get logits for the specific token index
+	logits := s.lc.GetLogits()
+	seq.logits = make([]float32, len(logits))
+	copy(seq.logits, logits)
+
+	vocabSize := s.model.NumVocab()
+	return probs(logits, vocabSize)
+}
+
 // TODO (jmorganca): processBatch should be simplified, removing:
 // * sampling
 // * stop token checking
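Two notes on probs. Sorting the entire vocabulary costs O(V log V) per emitted token even when the caller keeps only the top few entries; a partial top-k selection would avoid most of that work. Also, the max-logit subtraction before exponentiating is the standard guard against float32 overflow, and the same identity gives the log probability directly: log p_i = (logit_i - max) - log(sum_j exp(logit_j - max)). A sketch of that variant (my rewording of the normalization loop, not what the diff does):

    logSum := float32(math.Log(float64(sum)))
    for i := range probs {
    	probs[i].LogProb = (probs[i].Logit - maxLogit) - logSum
    	probs[i].Prob = float32(math.Exp(float64(probs[i].LogProb)))
    }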
@@ -483,6 +558,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)

 		seq.numPredicted++

+		resp := CompletionResponse{Content: piece}
+
+		if seq.logprobs > 0 {
+			// TODO: return selected token in logprobs always
+			resp.LogProbs = s.probs(seq)
+			// TODO: fix this logprobs limit
+			resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)]
+			for i := range resp.LogProbs {
+				// decode the token id to a piece
+				resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID)
+			}
+		}
+
 		// if it's an end of sequence token, break
 		if s.model.TokenIsEog(token) {
 			// TODO (jmorganca): we should send this back
@@ -495,16 +583,21 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)

 		seq.inputs = []input{{token: token}}

-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		// TODO: add probs here
+		seq.pendingResponses = append(seq.pendingResponses, resp)
+		var sequence string
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}

 		if ok, stop := findStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

+			// TODO: fix this stop sequence caching
 			var tokenTruncated bool
-			origLen := len(seq.pendingResponses)
-			seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
-			newLen := len(seq.pendingResponses)
+			origLen := len(sequence)
+			sequence, tokenTruncated = truncateStop(sequence, stop)
+			newLen := len(sequence)

 			// Update the cache based on the tokens that will be returned:
 			// - We have 1 token more than is currently in the cache because
@@ -575,6 +668,7 @@ type CompletionRequest struct {
 	Images      []ImageData `json:"image_data"`
 	Grammar     string      `json:"grammar"`
 	CachePrompt bool        `json:"cache_prompt"`
+	Logprobs    int         `json:"logprobs,omitempty"`

 	Options
 }
@@ -590,8 +684,10 @@ type CompletionResponse struct {
 	Content string `json:"content"`
 	Stop    bool   `json:"stop"`

-	Model  string `json:"model,omitempty"`
-	Prompt string `json:"prompt,omitempty"`
+	Model    string       `json:"model,omitempty"`
+	Prompt   string       `json:"prompt,omitempty"`
+	LogProbs []TokenProbs `json:"logprobs,omitempty"`

 	StoppedLimit bool    `json:"stopped_limit,omitempty"`
 	PredictedN   int     `json:"predicted_n,omitempty"`
 	PredictedMS  float64 `json:"predicted_ms,omitempty"`
@@ -609,10 +705,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}

-	// Set the headers to indicate streaming
-	w.Header().Set("Content-Type", "application/json")
-	w.Header().Set("Transfer-Encoding", "chunked")
-
 	flusher, ok := w.(http.Flusher)
 	if !ok {
 		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
@@ -641,6 +733,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:        req.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		logprobs:       req.Logprobs,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -688,11 +781,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		case <-r.Context().Done():
 			close(seq.quit)
 			return
-		case content, ok := <-seq.responses:
+		case resp, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Content: content,
-				}); err != nil {
+				fmt.Println("response", resp)
+				if err := json.NewEncoder(w).Encode(&resp); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
llama/runner/runner_test.go (new file, 58 lines)

@@ -0,0 +1,58 @@
+package runner
+
+import (
+	"math"
+	"testing"
+)
+
+func TestProbs(t *testing.T) {
+	// Input test data
+	logits := []float32{1.0, 2.0, 0.5, -1.0}
+	vocabSize := 4
+	want := []TokenProbs{
+		{TokenID: 1, Logit: 2.0},  // Highest logit
+		{TokenID: 0, Logit: 1.0},  // Second highest
+		{TokenID: 2, Logit: 0.5},  // Third
+		{TokenID: 3, Logit: -1.0}, // Lowest
+	}
+
+	got := probs(logits, vocabSize)
+
+	// Test 1: Check sorting order
+	for i := 0; i < len(got)-1; i++ {
+		if got[i].Logit < got[i+1].Logit {
+			t.Errorf("probs not properly sorted: logit at pos %d (%f) < logit at pos %d (%f)",
+				i, got[i].Logit, i+1, got[i+1].Logit)
+		}
+	}
+
+	// Test 2: Check probability normalization
+	var sum float32
+	for _, p := range got {
+		sum += p.Prob
+	}
+	if math.Abs(float64(sum-1.0)) > 1e-6 {
+		t.Errorf("probabilities do not sum to 1: got %v", sum)
+	}
+
+	// Test 3: Check token IDs match expected order
+	for i, want := range want {
+		if got[i].TokenID != want.TokenID {
+			t.Errorf("wrong token ID at position %d: got %d, want %d",
+				i, got[i].TokenID, want.TokenID)
+		}
+		if got[i].Logit != want.Logit {
+			t.Errorf("wrong logit at position %d: got %f, want %f",
+				i, got[i].Logit, want.Logit)
+		}
+	}
+
+	// Test 4: Check log probs are correctly calculated
+	for i, p := range got {
+		expectedLogProb := float32(math.Log(float64(p.Prob)))
+		if math.Abs(float64(p.LogProb-expectedLogProb)) > 1e-6 {
+			t.Errorf("wrong log prob at position %d: got %f, want %f",
+				i, p.LogProb, expectedLogProb)
+		}
+	}
+}
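As a sanity check on TestProbs: for the logits {1.0, 2.0, 0.5, -1.0}, subtracting the max (2.0) gives exponents e^-1 ≈ 0.368, e^0 = 1, e^-1.5 ≈ 0.223, and e^-3 ≈ 0.050, summing to about 1.641. After normalization, the sorted probabilities come out to roughly 0.609, 0.224, 0.136, and 0.030, which do sum to 1 within the test's 1e-6 tolerance.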
llama/runner/stop.go

@@ -26,43 +26,15 @@ func containsStopSuffix(sequence string, stops []string) bool {
 	return false
 }

-// truncateStop removes the provided stop string from pieces,
-// returning the partial pieces with stop removed, including truncating
-// the last piece if required (and signalling if this was the case)
-func truncateStop(pieces []string, stop string) ([]string, bool) {
-	joined := strings.Join(pieces, "")
-
-	index := strings.Index(joined, stop)
+// truncateStop removes the provided stop string from sequence,
+// returning both the truncated sequence and a bool indicating if truncation occurred
+func truncateStop(sequence string, stop string) (string, bool) {
+	index := strings.Index(sequence, stop)
 	if index == -1 {
-		return pieces, false
+		return sequence, false
 	}

-	joined = joined[:index]
-
-	// Split truncated string back into pieces of original lengths
-	lengths := make([]int, len(pieces))
-	for i, piece := range pieces {
-		lengths[i] = len(piece)
-	}
-
-	var result []string
-	tokenTruncated := false
-	start := 0
-	for _, length := range lengths {
-		if start >= len(joined) {
-			break
-		}
-
-		end := start + length
-		if end > len(joined) {
-			end = len(joined)
-			tokenTruncated = true
-		}
-		result = append(result, joined[start:end])
-		start = end
-	}
-
-	return result, tokenTruncated
+	return sequence[:index], true
 }

 func incompleteUnicode(token string) bool {
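The rewrite collapses truncateStop to a plain substring search, since the caller now works with one joined string instead of per-token pieces. Concretely, using values from the updated test table below:

    truncateStop("helloworld", "world")  // returns ("hello", true)
    truncateStop("hello world", "xyz")   // returns ("hello world", false)

Note that the boolean's meaning shifts: it used to signal that the final piece was partially cut, and now signals that a stop string was found and removed at all, which is why the "Single word" case's expectedTrunc flips from false to true.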
llama/runner/stop_test.go

@@ -1,60 +1,60 @@
 package runner

 import (
-	"reflect"
 	"testing"
 )

 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name          string
-		pieces        []string
+		sequence      string
 		stop          string
-		expected      []string
+		expected      string
 		expectedTrunc bool
 	}{
 		{
 			name:          "Single word",
-			pieces:        []string{"hello", "world"},
+			sequence:      "helloworld",
 			stop:          "world",
-			expected:      []string{"hello"},
-			expectedTrunc: false,
+			expected:      "hello",
+			expectedTrunc: true,
 		},
 		{
 			name:          "Partial",
-			pieces:        []string{"hello", "wor"},
+			sequence:      "hellowor",
 			stop:          "or",
-			expected:      []string{"hello", "w"},
+			expected:      "hellow",
 			expectedTrunc: true,
 		},
 		{
 			name:          "Suffix",
-			pieces:        []string{"Hello", " there", "!"},
+			sequence:      "Hello there!",
 			stop:          "!",
-			expected:      []string{"Hello", " there"},
-			expectedTrunc: false,
-		},
-		{
-			name:          "Suffix partial",
-			pieces:        []string{"Hello", " the", "re!"},
-			stop:          "there!",
-			expected:      []string{"Hello", " "},
+			expected:      "Hello there",
 			expectedTrunc: true,
 		},
 		{
 			name:          "Middle",
-			pieces:        []string{"hello", " wor"},
+			sequence:      "hello wor",
 			stop:          "llo w",
-			expected:      []string{"he"},
+			expected:      "he",
 			expectedTrunc: true,
 		},
+		{
+			name:          "No stop found",
+			sequence:      "hello world",
+			stop:          "xyz",
+			expected:      "hello world",
+			expectedTrunc: false,
+		},
 	}

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			result, resultTrunc := truncateStop(tt.pieces, tt.stop)
-			if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
-				t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
+			result, truncated := truncateStop(tt.sequence, tt.stop)
+			if result != tt.expected || truncated != tt.expectedTrunc {
+				t.Errorf("truncateStop(%q, %q): have %q (%v); want %q (%v)",
+					tt.sequence, tt.stop, result, truncated, tt.expected, tt.expectedTrunc)
 			}
 		})
 	}
llm/server.go

@@ -644,12 +644,22 @@ type ImageData struct {
 	AspectRatioID int    `json:"aspect_ratio_id"`
 }

+// TokenProbs represents probability information for a token
+type TokenProbs struct {
+	TokenID int     `json:"id"`
+	Logit   float32 `json:"logit"`
+	Prob    float32 `json:"prob"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
+}
+
 type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
+	Content      string       `json:"content"`
+	Model        string       `json:"model"`
+	Prompt       string       `json:"prompt"`
+	Stop         bool         `json:"stop"`
+	StoppedLimit bool         `json:"stopped_limit"`
+	LogProbs     []TokenProbs `json:"logprobs"`

 	Timings struct {
 		PredictedN int `json:"predicted_n"`
@@ -660,14 +670,16 @@ type completion struct {
 }

 type CompletionRequest struct {
-	Prompt  string
-	Format  json.RawMessage
-	Images  []ImageData
-	Options *api.Options
+	Prompt   string
+	Format   json.RawMessage
+	Images   []ImageData
+	LogProbs int
+	Options  *api.Options
 }

 type CompletionResponse struct {
 	Content         string
+	LogProbs        []TokenProbs
 	DoneReason      string
 	Done            bool
 	PromptEvalCount int
@@ -698,9 +710,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		"seed":         req.Options.Seed,
 		"stop":         req.Options.Stop,
 		"image_data":   req.Images,
+		"logprobs":     req.LogProbs,
 		"cache_prompt": true,
 	}

+	fmt.Println("completion request:", request)
+
 	if len(req.Format) > 0 {
 		switch string(req.Format) {
 		case `null`, `""`:
@@ -796,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			continue
 		}

-		// slog.Debug("got line", "line", string(line))
 		evt, ok := bytes.CutPrefix(line, []byte("data: "))
 		if !ok {
 			evt = line
@@ -822,7 +836,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu

 			if c.Content != "" {
 				fn(CompletionResponse{
-					Content: c.Content,
+					Content:  c.Content,
+					LogProbs: c.LogProbs,
 				})
 			}
@@ -839,6 +854,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 				EvalCount:          c.Timings.PredictedN,
 				EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				LogProbs:           c.LogProbs,
 			})
 			return nil
 		}
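The parsing loop above strips an SSE-style "data: " prefix and decodes each event into the completion struct, so a streamed runner event carrying logprobs would look roughly like this (an abridged illustration; the token id, logit, and probability values are invented):

    data: {"content":"Hello","logprobs":[{"id":9906,"logit":14.2,"prob":0.61,"logprob":-0.49,"token":"Hello"}]}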
server/routes.go

@@ -293,11 +293,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:   prompt,
+			Images:   images,
+			Format:   req.Format,
+			LogProbs: req.LogProbs,
+			Options:  opts,
 		}, func(cr llm.CompletionResponse) {
+			fmt.Printf("banana: %#v\n", cr)
 			res := api.GenerateResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
@@ -311,6 +313,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalDuration:    cr.EvalDuration,
 				},
 			}
+			for _, p := range cr.LogProbs {
+				res.LogProbs = append(res.LogProbs, api.TokenProbs{
+					TokenID: p.TokenID,
+					LogProb: p.LogProb,
+					Token:   p.Token,
+				})
+			}

 			if _, err := sb.WriteString(cr.Content); err != nil {
 				ch <- gin.H{"error": err.Error()}
@@ -1466,10 +1475,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		var sb strings.Builder
 		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:   prompt,
+			Images:   images,
+			Format:   req.Format,
+			LogProbs: req.LogProbs,
+			Options:  opts,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model: req.Model,
@@ -1484,6 +1494,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalDuration: r.EvalDuration,
 				},
 			}
+			for _, p := range r.LogProbs {
+				res.LogProbs = append(res.LogProbs, api.TokenProbs{
+					TokenID: p.TokenID,
+					LogProb: p.LogProb,
+					Token:   p.Token,
+				})
+			}

 			if r.Done {
 				res.TotalDuration = time.Since(checkpointStart)
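Both GenerateHandler and ChatHandler repeat the same field-by-field conversion from llm.TokenProbs to api.TokenProbs. A small helper could express it once; this is a hypothetical sketch, not part of the change:

    // toAPIProbs converts runner-level token probabilities into API ones.
    // Hypothetical helper; the name and placement are illustrative only.
    func toAPIProbs(in []llm.TokenProbs) []api.TokenProbs {
    	out := make([]api.TokenProbs, 0, len(in))
    	for _, p := range in {
    		out = append(out, api.TokenProbs{TokenID: p.TokenID, LogProb: p.LogProb, Token: p.Token})
    	}
    	return out
    }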