From fdbb0b5cfea9947b33f58565293206d6ef3a1b8a Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Thu, 13 Feb 2025 15:22:15 -0800
Subject: [PATCH] prototype

---
 api/types.go              |  21 ++++++--
 llama/llama.go            |  12 -----
 llama/runner/runner.go    |  90 ++++++++++++++++++--------------
 llama/runner/stop.go      |  45 +++-------------
 llama/runner/stop_test.go | 106 ++++++++++++++------------------------
 llm/server.go             |  38 ++++++++++----
 server/routes.go          |  33 +++++++++---
 7 files changed, 164 insertions(+), 181 deletions(-)

diff --git a/api/types.go b/api/types.go
index f4c5b1058..c508891d6 100644
--- a/api/types.go
+++ b/api/types.go
@@ -77,6 +77,8 @@ type GenerateRequest struct {
     // request, for multimodal models.
     Images []ImageData `json:"images,omitempty"`

+    LogProbs int `json:"logprobs,omitempty"`
+
     // Options lists model-specific options. For example, temperature can be
     // set through this field, if the model supports it.
     Options map[string]interface{} `json:"options"`
@@ -103,6 +105,8 @@ type ChatRequest struct {
     // Tools is an optional list of tools the model has access to.
     Tools `json:"tools,omitempty"`

+    LogProbs int `json:"logprobs,omitempty"`
+
     // Options lists model-specific options.
     Options map[string]interface{} `json:"options"`
 }
@@ -182,13 +186,20 @@ func (t *ToolFunction) String() string {
     return string(bts)
 }

+type TokenProbs struct {
+    TokenID int     `json:"id"`
+    LogProb float32 `json:"logprob"`
+    Token   string  `json:"token"`
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-    Model      string    `json:"model"`
-    CreatedAt  time.Time `json:"created_at"`
-    Message    Message   `json:"message"`
-    DoneReason string    `json:"done_reason,omitempty"`
+    Model      string       `json:"model"`
+    CreatedAt  time.Time    `json:"created_at"`
+    Message    Message      `json:"message"`
+    DoneReason string       `json:"done_reason,omitempty"`
+    LogProbs   []TokenProbs `json:"logprobs,omitempty"`

     Done bool `json:"done"`

@@ -452,6 +463,8 @@ type GenerateResponse struct {
     // can be sent in the next request to keep a conversational memory.
     Context []int `json:"context,omitempty"`

+    LogProbs []TokenProbs `json:"logprobs,omitempty"`
+
     Metrics
 }

diff --git a/llama/llama.go b/llama/llama.go
index b71aa51ac..90852c3d8 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -233,18 +233,6 @@ func (c *Context) GetLogits() []float32 {
     return unsafe.Slice((*float32)(logits), vocabSize)
 }

-func (m *Model) Detokenize(tokens []int) (string, error) {
-    var text string
-    for _, token := range tokens {
-        piece := m.TokenToPiece(token)
-        if piece == "" {
-            return "", fmt.Errorf("failed to convert token %d to piece", token)
-        }
-        text += piece
-    }
-    return text, nil
-}
-
 type ModelParams struct {
     NumGpuLayers int
     MainGpu      int
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 2f15f4bec..1e35fc19d 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -104,6 +104,7 @@ type NewSequenceParams struct {
     numKeep        int
     samplingParams *llama.SamplingParams
     embedding      bool
+    logprobs       int
 }

 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -164,6 +165,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
         embeddingOnly: params.embedding,
         stop:          params.stop,
         numKeep:       params.numKeep,
+        logprobs:      params.logprobs,
     }, nil
 }

@@ -285,37 +287,34 @@ func flushPending(seq *Sequence) bool {
     if len(seq.pendingResponses) == 0 {
         return true
     }
-    content := ""
+    resps := []CompletionResponse{}
     for _, resp := range seq.pendingResponses {
-        content += resp.Content
+        resps = append(resps, resp)
     }
     seq.pendingResponses = []CompletionResponse{}

-    // Check if there are any partial UTF-8 characters remaining.
-    // We already check and queue as we are generating but some may
-    // still make it here:
-    // - Sequence is ending, e.g. generation limit has been hit
-    // - Invalid characters in the middle of a string
-    // This is a stricter check to ensure we never output invalid Unicode.
-    for !utf8.ValidString(content) {
-        content = content[:len(content)-1]
+    // TODO: figure out this result logic
+    result := false
+    for _, resp := range resps {
+        // Check if there are any partial UTF-8 characters remaining.
+        // We already check and queue as we are generating but some may
+        // still make it here:
+        // - Sequence is ending, e.g. generation limit has been hit
+        // - Invalid characters in the middle of a string
+        // This is a stricter check to ensure we never output invalid Unicode.
+        for !utf8.ValidString(resp.Content) {
+            resp.Content = resp.Content[:len(resp.Content)-1]
+        }
+
+        select {
+        case seq.responses <- resp:
+            result = true
+        case <-seq.quit:
+            result = false
+        }
     }

-    // Add logits if requested and available
-    wantLogits := true
-    if wantLogits && seq.logits != nil {
-        // resp.Logits = seq.logits
-        seq.logits = nil
-    }
-
-    select {
-    case seq.responses <- CompletionResponse{
-        Content: content,
-    }:
-        return true
-    case <-seq.quit:
-        return false
-    }
+    return result
 }

 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -371,10 +370,11 @@ func (s *Server) run(ctx context.Context) {

 // TokenProbs represents probability information for a token
 type TokenProbs struct {
-    TokenID int
-    Logit   float32
-    Prob    float32
-    LogProb float32
+    TokenID int     `json:"id"`
+    Logit   float32 `json:"logit"`
+    Prob    float32 `json:"prob"`
+    LogProb float32 `json:"logprob"`
+    Token   string  `json:"token"`
 }

 // probs returns sorted token probabilities for a specific token index
@@ -553,9 +553,17 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)

         seq.numPredicted++

+        resp := CompletionResponse{Content: piece}
+        if seq.logprobs > 0 {
             // TODO: return selected token in logprobs always
-            // probs := s.probs(seq)
+            resp.LogProbs = s.probs(seq)
+            // TODO: fix this logprobs limit
+            resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)]
+            for i := range resp.LogProbs {
+                // decode the token id to a piece
+                resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID)
+            }
         }

         // if it's an end of sequence token, break
@@ -571,7 +579,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
         seq.inputs = []input{{token: token}}

         // TODO: add probs here
-        seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece})
+        seq.pendingResponses = append(seq.pendingResponses, resp)
         var sequence string
         for _, r := range seq.pendingResponses {
             sequence += r.Content
@@ -580,10 +588,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
         if ok, stop := findStop(sequence, seq.stop); ok {
             slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

+            // TODO: fix this stop sequence caching
             var tokenTruncated bool
-            origLen := len(seq.pendingResponses)
-            seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
-            newLen := len(seq.pendingResponses)
+            origLen := len(sequence)
+            sequence, tokenTruncated = truncateStop(sequence, stop)
+            newLen := len(sequence)

             // Update the cache based on the tokens that will be returned:
             // - We have 1 token more than is currently in the cache because
@@ -654,6 +663,7 @@ type CompletionRequest struct {
     Images      []ImageData `json:"image_data"`
     Grammar     string      `json:"grammar"`
     CachePrompt bool        `json:"cache_prompt"`
+    Logprobs    int         `json:"logprobs,omitempty"`

     Options
 }
@@ -669,8 +679,10 @@ type CompletionResponse struct {
     Content string `json:"content"`
     Stop    bool   `json:"stop"`

-    Model  string `json:"model,omitempty"`
-    Prompt string `json:"prompt,omitempty"`
+    Model    string       `json:"model,omitempty"`
+    Prompt   string       `json:"prompt,omitempty"`
+    LogProbs []TokenProbs `json:"logprobs,omitempty"`
+
     StoppedLimit bool    `json:"stopped_limit,omitempty"`
     PredictedN   int     `json:"predicted_n,omitempty"`
     PredictedMS  float64 `json:"predicted_ms,omitempty"`
@@ -688,10 +700,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
         return
     }

-    // Set the headers to indicate streaming
-    w.Header().Set("Content-Type", "application/json")
-    w.Header().Set("Transfer-Encoding", "chunked")
-
     flusher, ok := w.(http.Flusher)
     if !ok {
         http.Error(w, "Streaming not supported", http.StatusInternalServerError)
@@ -720,6 +728,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
         numKeep:        req.NumKeep,
         samplingParams: &samplingParams,
         embedding:      false,
+        logprobs:       req.Logprobs,
     })
     if err != nil {
         http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -769,6 +778,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
             return
         case resp, ok := <-seq.responses:
             if ok {
+                fmt.Println("response", resp)
                 if err := json.NewEncoder(w).Encode(&resp); err != nil {
                     http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
                     close(seq.quit)
diff --git a/llama/runner/stop.go b/llama/runner/stop.go
index 76f5cb688..ff5de43c6 100644
--- a/llama/runner/stop.go
+++ b/llama/runner/stop.go
@@ -26,46 +26,15 @@ func containsStopSuffix(sequence string, stops []string) bool {
     return false
 }

-// truncateStop removes the provided stop string from pieces,
-// returning the partial pieces with stop removed, including truncating
-// the last piece if required (and signalling if this was the case)
-func truncateStop(pieces []CompletionResponse, stop string) ([]CompletionResponse, bool) {
-    // Build complete string and find stop position
-    var completeStr string
-    for _, piece := range pieces {
-        completeStr += piece.Content
+// truncateStop removes the provided stop string from sequence,
+// returning both the truncated sequence and a bool indicating if truncation occurred
+func truncateStop(sequence string, stop string) (string, bool) {
+    index := strings.Index(sequence, stop)
+    if index == -1 {
+        return sequence, false
     }
-    stopStart := strings.Index(completeStr, stop)
-    if stopStart == -1 {
-        return pieces, false
-    }
-
-    // Build result up to stop position
-    result := make([]CompletionResponse, 0)
-    accumulated := 0
-
-    truncated := false
-    for _, piece := range pieces {
-        if accumulated+len(piece.Content) <= stopStart {
-            result = append(result, piece)
-            accumulated += len(piece.Content)
-            continue
-        }
-
-        if accumulated < stopStart {
-            truncPiece := piece
-            truncPiece.Content = piece.Content[:stopStart-accumulated]
-            if len(truncPiece.Content) > 0 {
-                result = append(result, truncPiece)
-                truncated = true
-            }
-        }
-        break
-    }
-
-    // Signal if we had to truncate the last piece
-    return result, truncated
+    return sequence[:index], true
 }

 func incompleteUnicode(token string) bool {
diff --git a/llama/runner/stop_test.go b/llama/runner/stop_test.go
index 17174f400..52637ff5e 100644
--- a/llama/runner/stop_test.go
+++ b/llama/runner/stop_test.go
@@ -1,90 +1,60 @@
 package runner

 import (
-    "reflect"
     "testing"
 )

 func TestTruncateStop(t *testing.T) {
     tests := []struct {
         name          string
-        pieces        []CompletionResponse
+        sequence      string
         stop          string
-        expected      []CompletionResponse
+        expected      string
         expectedTrunc bool
     }{
         {
-            name: "Single word",
-            pieces: []CompletionResponse{
-                {Content: "hello"},
-                {Content: "world"},
-            },
-            stop: "world",
-            expected: []CompletionResponse{
-                {Content: "hello"},
-            },
+            name:          "Single word",
+            sequence:      "helloworld",
+            stop:          "world",
+            expected:      "hello",
+            expectedTrunc: true,
+        },
+        {
+            name:          "Partial",
+            sequence:      "hellowor",
+            stop:          "or",
+            expected:      "hellow",
+            expectedTrunc: true,
+        },
+        {
+            name:          "Suffix",
+            sequence:      "Hello there!",
+            stop:          "!",
+            expected:      "Hello there",
+            expectedTrunc: true,
+        },
+        {
+            name:          "Middle",
+            sequence:      "hello wor",
+            stop:          "llo w",
+            expected:      "he",
+            expectedTrunc: true,
+        },
+        {
+            name:          "No stop found",
+            sequence:      "hello world",
+            stop:          "xyz",
+            expected:      "hello world",
             expectedTrunc: false,
         },
-        {
-            name: "Partial",
-            pieces: []CompletionResponse{
-                {Content: "hello"},
-                {Content: "wor"},
-            },
-            stop: "or",
-            expected: []CompletionResponse{
-                {Content: "hello"},
-                {Content: "w"},
-            },
-            expectedTrunc: true,
-        },
-        {
-            name: "Suffix",
-            pieces: []CompletionResponse{
-                {Content: "Hello"},
-                {Content: " there"},
-                {Content: "!"},
-            },
-            stop: "!",
-            expected: []CompletionResponse{
-                {Content: "Hello"},
-                {Content: " there"},
-            },
-            expectedTrunc: false,
-        },
-        {
-            name: "Suffix partial",
-            pieces: []CompletionResponse{
-                {Content: "Hello"},
-                {Content: " the"},
-                {Content: "re!"},
-            },
-            stop: "there!",
-            expected: []CompletionResponse{
-                {Content: "Hello"},
-                {Content: " "},
-            },
-            expectedTrunc: true,
-        },
-        {
-            name: "Middle",
-            pieces: []CompletionResponse{
-                {Content: "hello"},
-                {Content: " wor"},
-            },
-            stop: "llo w",
-            expected: []CompletionResponse{
-                {Content: "he"},
-            },
-            expectedTrunc: true,
-        },
     }

     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            result, resultTrunc := truncateStop(tt.pieces, tt.stop)
-            if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
-                t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
+            result, truncated := truncateStop(tt.sequence, tt.stop)
+            if result != tt.expected || truncated != tt.expectedTrunc {
+                t.Errorf("truncateStop(%q, %q): have %q (%v); want %q (%v)",
+                    tt.sequence, tt.stop, result, truncated, tt.expected, tt.expectedTrunc)
             }
         })
     }
diff --git a/llm/server.go b/llm/server.go
index 881209b39..0f409c7cf 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -644,12 +644,22 @@ type ImageData struct {
     AspectRatioID int `json:"aspect_ratio_id"`
 }

+// TokenProbs represents probability information for a token
+type TokenProbs struct {
+    TokenID int     `json:"id"`
+    Logit   float32 `json:"logit"`
+    Prob    float32 `json:"prob"`
+    LogProb float32 `json:"logprob"`
+    Token   string  `json:"token"`
+}
+
 type completion struct {
-    Content      string `json:"content"`
-    Model        string `json:"model"`
-    Prompt       string `json:"prompt"`
-    Stop         bool   `json:"stop"`
-    StoppedLimit bool   `json:"stopped_limit"`
+    Content      string       `json:"content"`
+    Model        string       `json:"model"`
+    Prompt       string       `json:"prompt"`
+    Stop         bool         `json:"stop"`
+    StoppedLimit bool         `json:"stopped_limit"`
+    LogProbs     []TokenProbs `json:"logprobs"`

     Timings struct {
         PredictedN int `json:"predicted_n"`
@@ -660,14 +670,16 @@
 }

 type CompletionRequest struct {
-    Prompt  string
-    Format  json.RawMessage
-    Images  []ImageData
-    Options *api.Options
+    Prompt   string
+    Format   json.RawMessage
+    Images   []ImageData
+    LogProbs int
+    Options  *api.Options
 }

 type CompletionResponse struct {
     Content            string
+    LogProbs           []TokenProbs
     DoneReason         string
     Done               bool
     PromptEvalCount    int
@@ -698,9 +710,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
         "seed":         req.Options.Seed,
         "stop":         req.Options.Stop,
         "image_data":   req.Images,
+        "logprobs":     req.LogProbs,
         "cache_prompt": true,
     }

+    fmt.Println("completion request:", request)
+
     if len(req.Format) > 0 {
         switch string(req.Format) {
         case `null`, `""`:
@@ -796,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
             continue
         }

-        // slog.Debug("got line", "line", string(line))
         evt, ok := bytes.CutPrefix(line, []byte("data: "))
         if !ok {
             evt = line
@@ -822,7 +836,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu

         if c.Content != "" {
             fn(CompletionResponse{
-                Content: c.Content,
+                Content:  c.Content,
+                LogProbs: c.LogProbs,
             })
         }

@@ -839,6 +854,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
                 PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
                 EvalCount:          c.Timings.PredictedN,
                 EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+                LogProbs:           c.LogProbs,
             })
             return nil
         }
diff --git a/server/routes.go b/server/routes.go
index 5a4bb485c..dfedc1c3b 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -293,11 +293,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
         var sb strings.Builder
         defer close(ch)
         if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-            Prompt:  prompt,
-            Images:  images,
-            Format:  req.Format,
-            Options: opts,
+            Prompt:   prompt,
+            Images:   images,
+            Format:   req.Format,
+            LogProbs: req.LogProbs,
+            Options:  opts,
         }, func(cr llm.CompletionResponse) {
+            fmt.Printf("banana: %#v\n", cr)
             res := api.GenerateResponse{
                 Model:     req.Model,
                 CreatedAt: time.Now().UTC(),
@@ -311,6 +313,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
                     EvalDuration:       cr.EvalDuration,
                 },
             }
+            for _, p := range cr.LogProbs {
+                res.LogProbs = append(res.LogProbs, api.TokenProbs{
+                    TokenID: p.TokenID,
+                    LogProb: p.LogProb,
+                    Token:   p.Token,
+                })
+            }
             if _, err := sb.WriteString(cr.Content); err != nil {
                 ch <- gin.H{"error": err.Error()}
             }
@@ -1466,10 +1475,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
         var sb strings.Builder
         var toolCallIndex int = 0
         if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-            Prompt:  prompt,
-            Images:  images,
-            Format:  req.Format,
-            Options: opts,
+            Prompt:   prompt,
+            Images:   images,
+            Format:   req.Format,
+            LogProbs: req.LogProbs,
+            Options:  opts,
         }, func(r llm.CompletionResponse) {
             res := api.ChatResponse{
                 Model: req.Model,
@@ -1484,6 +1494,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
                     EvalDuration:       r.EvalDuration,
                 },
             }
+            for _, p := range r.LogProbs {
+                res.LogProbs = append(res.LogProbs, api.TokenProbs{
+                    TokenID: p.TokenID,
+                    LogProb: p.LogProb,
+                    Token:   p.Token,
+                })
+            }

             if r.Done {
                 res.TotalDuration = time.Since(checkpointStart)