diff --git a/server/routes.go b/server/routes.go index 264f54027..785618de0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -304,26 +304,27 @@ func (s *Server) EmbedHandler(c *gin.Context) { } reqEmbedArray := make([]string, len(reqEmbed)) - for i, v := range reqEmbed { - s, err := func(v string, truncate bool) (string, error) { - tokens, err := r.Tokenize(c.Request.Context(), v) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return "", err - } + for i, s := range reqEmbed { + tokens, err := r.Tokenize(c.Request.Context(), s) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } - ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) - if len(tokens) > ctxLen { - if truncate { - tokens = tokens[:ctxLen] - return r.Detokenize(c.Request.Context(), tokens) - } else { - return "", fmt.Errorf("input length exceeds maximum context length") + ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) + if len(tokens) > ctxLen { + if *req.Truncate { + tokens = tokens[:ctxLen] + s, err = r.Detokenize(c.Request.Context(), tokens) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } + } else { + c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"}) + return } - - return v, nil - }(v, *req.Truncate) + } if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})