This commit is contained in:
Roy Han 2024-07-12 17:28:08 -07:00
parent 7e313e5964
commit 424f3f81a9

View File

@ -259,38 +259,42 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
if req.Truncate == nil {
truncate := true
req.Truncate = &truncate
}
reqEmbed := []string{}
var input []string
switch embeddings := req.Input.(type) {
switch i := req.Input.(type) {
case string:
if embeddings == "" {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
if len(i) > 0 {
input = append(input, i)
}
reqEmbed = []string{embeddings}
case []any:
if len(embeddings) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
for _, v := range embeddings {
for _, v := range i {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
reqEmbed = append(reqEmbed, v.(string))
input = append(input, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
@ -303,8 +307,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
reqEmbedArray := make([]string, len(reqEmbed))
for i, s := range reqEmbed {
reqEmbedArray := make([]string, len(input))
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -313,17 +317,17 @@ func (s *Server) EmbedHandler(c *gin.Context) {
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if *req.Truncate {
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
reqEmbedArray[i] = s
@ -331,7 +335,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
@ -1030,7 +1034,7 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler)
r.POST("/api/copy", s.CopyModelHandler)