send completion response on chan

This commit is contained in:
Bruce MacDonald 2025-02-12 17:03:52 -08:00
parent 7d16ec8fe8
commit 6dfcdec2da

View File

@ -61,7 +61,7 @@ type Sequence struct {
crossAttention bool
// channel to send responses over
responses chan string
responses chan CompletionResponse
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@ -153,7 +153,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
responses: make(chan CompletionResponse, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
@ -281,7 +281,7 @@ func flushPending(seq *Sequence) bool {
if len(seq.pendingResponses) == 0 {
return true
}
joined := strings.Join(seq.pendingResponses, "")
content := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
// Check if there are any partial UTF-8 characters remaining.
@ -290,8 +290,8 @@ func flushPending(seq *Sequence) bool {
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
for !utf8.ValidString(content) {
content = content[:len(content)-1]
}
// Add logits if requested and available
@ -302,7 +302,9 @@ func flushPending(seq *Sequence) bool {
}
select {
case seq.responses <- joined:
case seq.responses <- CompletionResponse{
Content: content,
}:
return true
case <-seq.quit:
return false
@ -755,11 +757,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
case <-r.Context().Done():
close(seq.quit)
return
case content, ok := <-seq.responses:
case resp, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Content: content,
}); err != nil {
if err := json.NewEncoder(w).Encode(&resp); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return