parallelized

This commit is contained in:
Roy Han 2024-07-12 16:08:12 -07:00
parent 1f3aefd323
commit 7cddd6d741

View File

@ -18,6 +18,7 @@ import (
"path/filepath"
"slices"
"strings"
"sync"
"syscall"
"time"
@ -304,34 +305,56 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
reqEmbedArray := make([]string, len(reqEmbed))
errCh := make(chan error, len(reqEmbed))
var wg sync.WaitGroup
var mu sync.Mutex
sem := make(chan struct{}, 5)
for i, s := range reqEmbed {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if *req.Truncate {
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
wg.Add(1)
sem <- struct{}{}
go func(i int, s string) {
defer wg.Done()
defer func() { <-sem }()
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if *req.Truncate {
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return
}
}
if err != nil {
errCh <- err
return
}
mu.Lock()
reqEmbedArray[i] = s
mu.Unlock()
}(i, s)
}
go func() {
wg.Wait()
close(errCh)
}()
for err := range errCh {
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
reqEmbedArray[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
if err != nil {