This commit is contained in:
Roy Han 2024-07-11 17:28:55 -07:00
parent 694388db90
commit dbe9527305

View File

@ -8,7 +8,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"log/slog" "log/slog"
"math" "math"
"net" "net"
@ -260,33 +259,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
if req.Truncate == nil { if req.Truncate == nil {
truncate := true truncate := true
req.Truncate = &truncate req.Truncate = &truncate
} }
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
reqEmbed := []string{} reqEmbed := []string{}
switch embeddings := req.Input.(type) { switch embeddings := req.Input.(type) {
@ -314,20 +291,22 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil { if err != nil {
handleScheduleError(c, req.Model, err) handleScheduleError(c, req.Model, err)
return return
} }
kvData, err := getKVData(model.ModelPath, false) kvData, err := getKVData(m.ModelPath, false)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
checkFit := func(s string, truncate bool) (string, error) { reqEmbedArray := make([]string, len(reqEmbed))
tokens, err := r.Tokenize(c.Request.Context(), s) for i, v := range reqEmbed {
s, err := func(v string, truncate bool) (string, error) {
tokens, err := r.Tokenize(c.Request.Context(), v)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return "", err return "", err
@ -343,12 +322,9 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
} }
return s, nil return v, nil
} }(v, *req.Truncate)
reqEmbedArray := make([]string, len(reqEmbed))
for i, v := range reqEmbed {
s, err := checkFit(v, *req.Truncate)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return