diff --git a/server/routes.go b/server/routes.go index 785618de0..ae8939e9e 100644 --- a/server/routes.go +++ b/server/routes.go @@ -18,6 +18,7 @@ import ( "path/filepath" "slices" "strings" + "sync" "syscall" "time" @@ -304,34 +305,56 @@ func (s *Server) EmbedHandler(c *gin.Context) { } reqEmbedArray := make([]string, len(reqEmbed)) + errCh := make(chan error, len(reqEmbed)) + var wg sync.WaitGroup + var mu sync.Mutex + sem := make(chan struct{}, 5) for i, s := range reqEmbed { - tokens, err := r.Tokenize(c.Request.Context(), s) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) - if len(tokens) > ctxLen { - if *req.Truncate { - tokens = tokens[:ctxLen] - s, err = r.Detokenize(c.Request.Context(), tokens) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } else { - c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"}) + wg.Add(1) + sem <- struct{}{} + go func(i int, s string) { + defer wg.Done() + defer func() { <-sem }() + tokens, err := r.Tokenize(c.Request.Context(), s) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - } + ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) + if len(tokens) > ctxLen { + if *req.Truncate { + tokens = tokens[:ctxLen] + s, err = r.Detokenize(c.Request.Context(), tokens) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } else { + c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"}) + return + } + } + if err != nil { + errCh <- err + return + } + mu.Lock() + reqEmbedArray[i] = s + mu.Unlock() + }(i, s) + } + go func() { + wg.Wait() + close(errCh) + }() + for err := range errCh { if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - reqEmbedArray[i] = s } + embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray) if err != nil {