This commit is contained in:
Roy Han 2024-07-11 17:28:55 -07:00
parent 694388db90
commit dbe9527305

View File

@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"net"
@ -260,33 +259,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
if req.Truncate == nil {
truncate := true
req.Truncate = &truncate
}
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
reqEmbed := []string{}
switch embeddings := req.Input.(type) {
@ -314,41 +291,40 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
kvData, err := getKVData(model.ModelPath, false)
kvData, err := getKVData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// checkFit makes sure s fits into the usable context window, which is the
// smaller of opts.NumCtx and the context length recorded in kvData.
//
// It returns s unchanged when its token count is within the limit. When the
// input is too long it returns the detokenized first ctxLen tokens if truncate
// is true, and an error otherwise. Tokenize/Detokenize errors are returned
// as-is.
checkFit := func(s string, truncate bool) (string, error) {
	tokens, err := r.Tokenize(c.Request.Context(), s)
	if err != nil {
		// Do not write a response here: the call site already turns a
		// non-nil error into a JSON error response, and writing twice would
		// put two JSON bodies on the wire for a single request.
		return "", err
	}

	ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
	if len(tokens) <= ctxLen {
		return s, nil
	}
	if !truncate {
		return "", fmt.Errorf("input length exceeds maximum context length")
	}
	return r.Detokenize(c.Request.Context(), tokens[:ctxLen])
}
reqEmbedArray := make([]string, len(reqEmbed))
for i, v := range reqEmbed {
s, err := checkFit(v, *req.Truncate)
// Fit the input into the usable context window, which is the smaller of
// opts.NumCtx and the context length recorded in kvData. Inputs within the
// limit pass through unchanged; longer inputs are cut to the first ctxLen
// tokens when truncate is set and rejected with an error otherwise.
s, err := func(v string, truncate bool) (string, error) {
	tokens, err := r.Tokenize(c.Request.Context(), v)
	if err != nil {
		// Do not write a response here: the error check that follows this
		// call already emits the JSON error response, and writing twice
		// would put two JSON bodies on the wire for a single request.
		return "", err
	}

	ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
	if len(tokens) <= ctxLen {
		return v, nil
	}
	if !truncate {
		return "", fmt.Errorf("input length exceeds maximum context length")
	}
	return r.Detokenize(c.Request.Context(), tokens[:ctxLen])
}(v, *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return