From 1daac5265179daaa0b0632669aa7df380ab8e274 Mon Sep 17 00:00:00 2001
From: Roy Han
Date: Mon, 1 Jul 2024 11:55:16 -0700
Subject: [PATCH] Truncation

---
 server/routes.go | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/server/routes.go b/server/routes.go
index 1051a3c2a..5a810b274 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -394,6 +394,24 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	// truncate returns s unchanged when it tokenizes to at most opts.NumCtx
+	// tokens; otherwise it keeps only the last opts.NumCtx tokens and returns
+	// their detokenized text. It never writes to the response itself, so the
+	// caller reports a returned error exactly once.
+	truncate := func(s string) (string, error) {
+		tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
+		if err != nil {
+			return "", err
+		}
+
+		if len(tokens) > opts.NumCtx {
+			tokens = tokens[len(tokens)-opts.NumCtx:]
+			return runner.llama.Detokenize(c.Request.Context(), tokens)
+		}
+
+		return s, nil
+	}
+
 	embeddings := [][]float64{}
 
 	switch reqEmbed := req.Input.(type) {
@@ -402,6 +420,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
 			return
 		}
+		if *req.Truncate {
+			reqEmbed, err = truncate(reqEmbed)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+		}
 		embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
 	case []any:
 		if reqEmbed == nil {
@@ -412,6 +437,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		reqEmbedArray := make([]string, len(reqEmbed))
 		for i, v := range reqEmbed {
 			if s, ok := v.(string); ok {
+				if *req.Truncate {
+					s, err = truncate(s)
+					if err != nil {
+						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+						return
+					}
+				}
 				reqEmbedArray[i] = s
 			} else {
 				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})