diff --git a/server/routes.go b/server/routes.go index 0aecba066..e6be65d04 100644 --- a/server/routes.go +++ b/server/routes.go @@ -259,38 +259,42 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } + truncate := true + + if req.Truncate != nil && !*req.Truncate { + truncate = false + } + if req.Truncate == nil { truncate := true req.Truncate = &truncate } - reqEmbed := []string{} + var input []string - switch embeddings := req.Input.(type) { + switch i := req.Input.(type) { case string: - if embeddings == "" { - c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) - return + if len(i) > 0 { + input = append(input, i) } - reqEmbed = []string{embeddings} case []any: - if len(embeddings) == 0 { - c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) - return - } - - for _, v := range embeddings { + for _, v := range i { if _, ok := v.(string); !ok { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) return } - reqEmbed = append(reqEmbed, v.(string)) + input = append(input, v.(string)) } default: c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) return } + if len(input) == 0 { + c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) + return + } + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) @@ -303,8 +307,8 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - reqEmbedArray := make([]string, len(reqEmbed)) - for i, s := range reqEmbed { + reqEmbedArray := make([]string, len(input)) + for i, s := range input { tokens, err := r.Tokenize(c.Request.Context(), s) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -313,17 +317,17 @@ func (s *Server) EmbedHandler(c *gin.Context) { ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) if len(tokens) > ctxLen { - if *req.Truncate { - tokens = tokens[:ctxLen] - s, err = r.Detokenize(c.Request.Context(), tokens) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - } else { + if !truncate { c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"}) return } + + tokens = tokens[:ctxLen] + s, err = r.Detokenize(c.Request.Context(), tokens) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } } reqEmbedArray[i] = s @@ -331,7 +335,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray) if err != nil { - slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) + slog.Error("embedding generation failed", "error", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) return } @@ -1030,7 +1034,7 @@ func (s *Server) GenerateRoutes() http.Handler { r.POST("/api/generate", s.GenerateHandler) r.POST("/api/chat", s.ChatHandler) r.POST("/api/embed", s.EmbedHandler) - r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy + r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/create", s.CreateModelHandler) r.POST("/api/push", s.PushModelHandler) r.POST("/api/copy", s.CopyModelHandler)