parallelized
This commit is contained in:
parent
1f3aefd323
commit
7cddd6d741
@ -18,6 +18,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -304,34 +305,56 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
reqEmbedArray := make([]string, len(reqEmbed))
|
reqEmbedArray := make([]string, len(reqEmbed))
|
||||||
|
errCh := make(chan error, len(reqEmbed))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var mu sync.Mutex
|
||||||
|
sem := make(chan struct{}, 5)
|
||||||
for i, s := range reqEmbed {
|
for i, s := range reqEmbed {
|
||||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
wg.Add(1)
|
||||||
if err != nil {
|
sem <- struct{}{}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
go func(i int, s string) {
|
||||||
return
|
defer wg.Done()
|
||||||
}
|
defer func() { <-sem }()
|
||||||
|
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
if err != nil {
|
||||||
if len(tokens) > ctxLen {
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
if *req.Truncate {
|
|
||||||
tokens = tokens[:ctxLen]
|
|
||||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||||
|
if len(tokens) > ctxLen {
|
||||||
|
if *req.Truncate {
|
||||||
|
tokens = tokens[:ctxLen]
|
||||||
|
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
reqEmbedArray[i] = s
|
||||||
|
mu.Unlock()
|
||||||
|
}(i, s)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(errCh)
|
||||||
|
}()
|
||||||
|
for err := range errCh {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reqEmbedArray[i] = s
|
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
|
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user