clean up
This commit is contained in:
parent
694388db90
commit
dbe9527305
@ -8,7 +8,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
@ -260,33 +259,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Truncate == nil {
|
||||
truncate := true
|
||||
req.Truncate = &truncate
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
if errors.As(err, &pErr) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
reqEmbed := []string{}
|
||||
|
||||
switch embeddings := req.Input.(type) {
|
||||
@ -314,41 +291,40 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
kvData, err := getKVData(model.ModelPath, false)
|
||||
kvData, err := getKVData(m.ModelPath, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
checkFit := func(s string, truncate bool) (string, error) {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if truncate {
|
||||
tokens = tokens[:ctxLen]
|
||||
return r.Detokenize(c.Request.Context(), tokens)
|
||||
} else {
|
||||
return "", fmt.Errorf("input length exceeds maximum context length")
|
||||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
reqEmbedArray := make([]string, len(reqEmbed))
|
||||
for i, v := range reqEmbed {
|
||||
s, err := checkFit(v, *req.Truncate)
|
||||
s, err := func(v string, truncate bool) (string, error) {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), v)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if truncate {
|
||||
tokens = tokens[:ctxLen]
|
||||
return r.Detokenize(c.Request.Context(), tokens)
|
||||
} else {
|
||||
return "", fmt.Errorf("input length exceeds maximum context length")
|
||||
}
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}(v, *req.Truncate)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user