send completion response on chan

This commit is contained in:
Bruce MacDonald 2025-02-12 17:03:52 -08:00
parent 7d16ec8fe8
commit 6dfcdec2da

View File

@ -61,7 +61,7 @@ type Sequence struct {
crossAttention bool crossAttention bool
// channel to send responses over // channel to send responses over
responses chan string responses chan CompletionResponse
// channel to stop decoding (such as if the remote connection is closed) // channel to stop decoding (such as if the remote connection is closed)
quit chan bool quit chan bool
@ -153,7 +153,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
startProcessingTime: startTime, startProcessingTime: startTime,
numPredict: params.numPredict, numPredict: params.numPredict,
pendingResponses: make([]string, 0), pendingResponses: make([]string, 0),
responses: make(chan string, 100), responses: make(chan CompletionResponse, 100),
quit: make(chan bool, 1), quit: make(chan bool, 1),
embedding: make(chan []float32, 1), embedding: make(chan []float32, 1),
samplingCtx: sc, samplingCtx: sc,
@ -281,7 +281,7 @@ func flushPending(seq *Sequence) bool {
if len(seq.pendingResponses) == 0 { if len(seq.pendingResponses) == 0 {
return true return true
} }
joined := strings.Join(seq.pendingResponses, "") content := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{} seq.pendingResponses = []string{}
// Check if there are any partial UTF-8 characters remaining. // Check if there are any partial UTF-8 characters remaining.
@ -290,8 +290,8 @@ func flushPending(seq *Sequence) bool {
// - Sequence is ending, e.g. generation limit has been hit // - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string // - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode. // This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) { for !utf8.ValidString(content) {
joined = joined[:len(joined)-1] content = content[:len(content)-1]
} }
// Add logits if requested and available // Add logits if requested and available
@ -302,7 +302,9 @@ func flushPending(seq *Sequence) bool {
} }
select { select {
case seq.responses <- joined: case seq.responses <- CompletionResponse{
Content: content,
}:
return true return true
case <-seq.quit: case <-seq.quit:
return false return false
@ -755,11 +757,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
case <-r.Context().Done(): case <-r.Context().Done():
close(seq.quit) close(seq.quit)
return return
case content, ok := <-seq.responses: case resp, ok := <-seq.responses:
if ok { if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{ if err := json.NewEncoder(w).Encode(&resp); err != nil {
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit) close(seq.quit)
return return