diff --git a/llm/server.go b/llm/server.go
index e6046db60..a2bc1548f 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -675,9 +675,32 @@ type CompletionRequest struct {
 	Grammar string // set before sending the request to the subprocess
 }
 
+// DoneReason represents the reason why a completion response is done
+type DoneReason int
+
+const (
+	// DoneReasonStop indicates the completion stopped naturally
+	DoneReasonStop DoneReason = iota
+	// DoneReasonLength indicates the completion stopped due to length limits
+	DoneReasonLength
+	// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
+	DoneReasonConnectionClosed
+)
+
+func (d DoneReason) String() string {
+	switch d {
+	case DoneReasonLength:
+		return "length"
+	case DoneReasonStop:
+		return "stop"
+	default:
+		return "" // closed
+	}
+}
+
 type CompletionResponse struct {
 	Content            string        `json:"content"`
-	DoneReason         string        `json:"done_reason"`
+	DoneReason         DoneReason    `json:"done_reason"`
 	Done               bool          `json:"done"`
 	PromptEvalCount    int           `json:"prompt_eval_count"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
@@ -786,7 +809,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 					continue
 				}
 
-				// slog.Debug("got line", "line", string(line))
 				evt, ok := bytes.CutPrefix(line, []byte("data: "))
 				if !ok {
 					evt = line
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index a4264f5fc..d8169be40 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -83,7 +83,7 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
-	doneReason string
+	doneReason llm.DoneReason
 
 	// Metrics
 	startProcessingTime time.Time
@@ -301,7 +301,7 @@ func flushPending(seq *Sequence) bool {
 	}
 }
 
-func (s *Server) removeSequence(seqIndex int, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
 	flushPending(seq)
@@ -380,7 +380,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(seqIdx, llm.DoneReasonLength)
 			continue
 		}
 
@@ -482,7 +482,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 
 			seq.embedding <- embed
-			s.removeSequence(i, "")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -499,7 +499,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
 
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -530,7 +530,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				}
 				seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
 
-				s.removeSequence(i, "stop")
+				s.removeSequence(i, llm.DoneReasonStop)
 				continue
 			}
 
@@ -543,7 +543,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		}
 
 		if !flushPending(seq) {
-			s.removeSequence(i, "connection")
+			s.removeSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}
 
@@ -657,14 +657,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 
 				flusher.Flush()
 			} else {
-				// Send the final response
-				doneReason := "stop"
-				if seq.doneReason == "limit" {
-					doneReason = "length"
-				}
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Done:               true,
-					DoneReason:         doneReason,
+					DoneReason:         seq.doneReason,
 					PromptEvalCount:    seq.numPromptInputs,
 					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 					EvalCount:          seq.numDecoded,
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index f3286abae..7b7e09402 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -82,7 +82,7 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
-	doneReason string
+	doneReason llm.DoneReason
 
 	// Metrics
 	startProcessingTime time.Time
@@ -341,7 +341,7 @@ func flushPending(seq *Sequence) bool {
 	}
 }
 
-func (s *Server) removeSequence(seqIndex int, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
 	flushPending(seq)
@@ -391,7 +391,7 @@ func (s *Server) processBatch() error {
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(seqIdx, llm.DoneReasonLength)
 			continue
 		}
 
@@ -510,7 +510,7 @@ func (s *Server) processBatch() error {
 		if seq.embeddingOnly {
 			// TODO(jessegross): Embedding support
 			slog.Warn("generation of embedding outputs not yet supported")
-			s.removeSequence(i, "")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -528,7 +528,7 @@ func (s *Server) processBatch() error {
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
 
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -564,7 +564,7 @@ func (s *Server) processBatch() error {
 				}
 				seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
 
-				s.removeSequence(i, "stop")
+				s.removeSequence(i, llm.DoneReasonStop)
 				continue
 			}
 
@@ -577,7 +577,7 @@ func (s *Server) processBatch() error {
 		}
 
 		if !flushPending(seq) {
-			s.removeSequence(i, "connection")
+			s.removeSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}
 
@@ -690,14 +690,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 
 				flusher.Flush()
 			} else {
-				// Send the final response
-				doneReason := "stop"
-				if seq.doneReason == "limit" {
-					doneReason = "length"
-				}
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Done:               true,
-					DoneReason:         doneReason,
+					DoneReason:         seq.doneReason,
 					PromptEvalCount:    seq.numPromptInputs,
 					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 					EvalCount:          seq.numPredicted,
diff --git a/server/routes.go b/server/routes.go
index eee34033e..906426b18 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -308,11 +308,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		Options: opts,
 	}, func(cr llm.CompletionResponse) {
 		res := api.GenerateResponse{
-			Model:      req.Model,
-			CreatedAt:  time.Now().UTC(),
-			Response:   cr.Content,
-			Done:       cr.Done,
-			DoneReason: cr.DoneReason,
+			Model:     req.Model,
+			CreatedAt: time.Now().UTC(),
+			Response:  cr.Content,
+			Done:      cr.Done,
 			Metrics: api.Metrics{
 				PromptEvalCount:    cr.PromptEvalCount,
 				PromptEvalDuration: cr.PromptEvalDuration,
@@ -326,6 +325,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		if cr.Done {
+			res.DoneReason = cr.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 
@@ -1533,11 +1533,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		Options: opts,
 	}, func(r llm.CompletionResponse) {
 		res := api.ChatResponse{
-			Model:      req.Model,
-			CreatedAt:  time.Now().UTC(),
-			Message:    api.Message{Role: "assistant", Content: r.Content},
-			Done:       r.Done,
-			DoneReason: r.DoneReason,
+			Model:     req.Model,
+			CreatedAt: time.Now().UTC(),
+			Message:   api.Message{Role: "assistant", Content: r.Content},
+			Done:      r.Done,
 			Metrics: api.Metrics{
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
@@ -1547,6 +1546,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 
 		if r.Done {
+			res.DoneReason = r.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 		}
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index aa263bf97..f219387c3 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -58,7 +58,7 @@ func TestGenerateChat(t *testing.T) {
 	mock := mockRunner{
 		CompletionResponse: llm.CompletionResponse{
 			Done:               true,
-			DoneReason:         "stop",
+			DoneReason:         llm.DoneReasonStop,
 			PromptEvalCount:    1,
 			PromptEvalDuration: 1,
 			EvalCount:          1,
@@ -401,7 +401,7 @@ func TestGenerateChat(t *testing.T) {
 		mock.CompletionResponse = llm.CompletionResponse{
 			Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
 			Done:               true,
-			DoneReason:         "done",
+			DoneReason:         llm.DoneReasonStop,
 			PromptEvalCount:    1,
 			PromptEvalDuration: 1,
 			EvalCount:          1,
@@ -519,7 +519,7 @@ func TestGenerateChat(t *testing.T) {
 			{
 				Content:            `, WA","unit":"celsius"}}`,
 				Done:               true,
-				DoneReason:         "tool_call",
+				DoneReason:         llm.DoneReasonStop,
 				PromptEvalCount:    3,
 				PromptEvalDuration: 1,
 			},
@@ -594,7 +594,7 @@ func TestGenerate(t *testing.T) {
 	mock := mockRunner{
 		CompletionResponse: llm.CompletionResponse{
 			Done:               true,
-			DoneReason:         "stop",
+			DoneReason:         llm.DoneReasonStop,
 			PromptEvalCount:    1,
 			PromptEvalDuration: 1,
 			EvalCount:          1,
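
For reference, a minimal standalone sketch of how the new enum behaves on the wire. The `DoneReason` type and its `String()` method are copied verbatim from the `llm/server.go` hunk above; the `main` function and its printed output are illustrative only and not part of the change. Because `DoneReason` is an `int`, `encoding/json` now encodes the internal `done_reason` field between the runner subprocess and the server as a JSON number, which is why the handlers in `server/routes.go` call `String()` once `Done` is true, keeping the user-facing API's `done_reason` a string such as `"stop"` or `"length"`.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// DoneReason and String() are copied from the llm/server.go hunk above.
type DoneReason int

const (
	DoneReasonStop DoneReason = iota
	DoneReasonLength
	DoneReasonConnectionClosed
)

func (d DoneReason) String() string {
	switch d {
	case DoneReasonLength:
		return "length"
	case DoneReasonStop:
		return "stop"
	default:
		return "" // closed
	}
}

func main() {
	// The runner-to-server wire format now carries the enum as a JSON number.
	b, _ := json.Marshal(struct {
		DoneReason DoneReason `json:"done_reason"`
	}{DoneReason: DoneReasonLength})
	fmt.Println(string(b)) // {"done_reason":1}

	// The user-facing API converts it back to a string via String().
	fmt.Println(DoneReasonStop)             // stop
	fmt.Println(DoneReasonLength)           // length
	fmt.Println(DoneReasonConnectionClosed) // prints an empty line (connection closed)
}
```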