diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 8e9698439..eda3d109b 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -61,7 +61,7 @@ type Sequence struct { crossAttention bool // channel to send responses over - responses chan string + responses chan CompletionResponse // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -153,7 +153,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen startProcessingTime: startTime, numPredict: params.numPredict, pendingResponses: make([]string, 0), - responses: make(chan string, 100), + responses: make(chan CompletionResponse, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), samplingCtx: sc, @@ -281,7 +281,7 @@ func flushPending(seq *Sequence) bool { if len(seq.pendingResponses) == 0 { return true } - joined := strings.Join(seq.pendingResponses, "") + content := strings.Join(seq.pendingResponses, "") seq.pendingResponses = []string{} // Check if there are any partial UTF-8 characters remaining. @@ -290,8 +290,8 @@ func flushPending(seq *Sequence) bool { // - Sequence is ending, e.g. generation limit has been hit // - Invalid characters in the middle of a string // This is a stricter check to ensure we never output invalid Unicode. - for !utf8.ValidString(joined) { - joined = joined[:len(joined)-1] + for !utf8.ValidString(content) { + content = content[:len(content)-1] } // Add logits if requested and available @@ -302,7 +302,9 @@ func flushPending(seq *Sequence) bool { } select { - case seq.responses <- joined: + case seq.responses <- CompletionResponse{ + Content: content, + }: return true case <-seq.quit: return false @@ -755,11 +757,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { case <-r.Context().Done(): close(seq.quit) return - case content, ok := <-seq.responses: + case resp, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Content: content, - }); err != nil { + if err := json.NewEncoder(w).Encode(&resp); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) close(seq.quit) return