From dbe95273053816025befb73ac2048ffb5eb2ec08 Mon Sep 17 00:00:00 2001 From: Roy Han Date: Thu, 11 Jul 2024 17:28:55 -0700 Subject: [PATCH] clean up --- server/routes.go | 68 ++++++++++++++++-------------------------------- 1 file changed, 22 insertions(+), 46 deletions(-) diff --git a/server/routes.go b/server/routes.go index fe46ab038..264f54027 100644 --- a/server/routes.go +++ b/server/routes.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "io/fs" "log/slog" "math" "net" @@ -260,33 +259,11 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - if req.Model == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - } - if req.Truncate == nil { truncate := true req.Truncate = &truncate } - model, err := GetModel(req.Model) - if err != nil { - var pErr *fs.PathError - if errors.As(err, &pErr) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - opts, err := modelOptions(model, req.Options) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - reqEmbed := []string{} switch embeddings := req.Input.(type) { @@ -314,41 +291,40 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return } - kvData, err := getKVData(model.ModelPath, false) + kvData, err := getKVData(m.ModelPath, false) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - checkFit := func(s string, truncate bool) (string, error) { - tokens, err := r.Tokenize(c.Request.Context(), s) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return "", err - } - - ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) - if len(tokens) > ctxLen { - if truncate { - tokens = tokens[:ctxLen] - return r.Detokenize(c.Request.Context(), tokens) - } else { - return "", fmt.Errorf("input length exceeds maximum context length") - } - } - - return s, nil - } - reqEmbedArray := make([]string, len(reqEmbed)) for i, v := range reqEmbed { - s, err := checkFit(v, *req.Truncate) + s, err := func(v string, truncate bool) (string, error) { + tokens, err := r.Tokenize(c.Request.Context(), v) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return "", err + } + + ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) + if len(tokens) > ctxLen { + if truncate { + tokens = tokens[:ctxLen] + return r.Detokenize(c.Request.Context(), tokens) + } else { + return "", fmt.Errorf("input length exceeds maximum context length") + } + } + + return v, nil + }(v, *req.Truncate) + if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return