clean up
This commit is contained in:
parent
7e313e5964
commit
424f3f81a9
@ -259,38 +259,42 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
truncate := true
|
||||
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
truncate = false
|
||||
}
|
||||
|
||||
if req.Truncate == nil {
|
||||
truncate := true
|
||||
req.Truncate = &truncate
|
||||
}
|
||||
|
||||
reqEmbed := []string{}
|
||||
var input []string
|
||||
|
||||
switch embeddings := req.Input.(type) {
|
||||
switch i := req.Input.(type) {
|
||||
case string:
|
||||
if embeddings == "" {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||
return
|
||||
if len(i) > 0 {
|
||||
input = append(input, i)
|
||||
}
|
||||
reqEmbed = []string{embeddings}
|
||||
case []any:
|
||||
if len(embeddings) == 0 {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range embeddings {
|
||||
for _, v := range i {
|
||||
if _, ok := v.(string); !ok {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||
return
|
||||
}
|
||||
reqEmbed = append(reqEmbed, v.(string))
|
||||
input = append(input, v.(string))
|
||||
}
|
||||
default:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||
return
|
||||
}
|
||||
|
||||
if len(input) == 0 {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
@ -303,8 +307,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
reqEmbedArray := make([]string, len(reqEmbed))
|
||||
for i, s := range reqEmbed {
|
||||
reqEmbedArray := make([]string, len(input))
|
||||
for i, s := range input {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@ -313,17 +317,17 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if *req.Truncate {
|
||||
tokens = tokens[:ctxLen]
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if !truncate {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
||||
return
|
||||
}
|
||||
|
||||
tokens = tokens[:ctxLen]
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
reqEmbedArray[i] = s
|
||||
@ -331,7 +335,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
|
||||
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
slog.Error("embedding generation failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
}
|
||||
@ -1030,7 +1034,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
r.POST("/api/create", s.CreateModelHandler)
|
||||
r.POST("/api/push", s.PushModelHandler)
|
||||
r.POST("/api/copy", s.CopyModelHandler)
|
||||
|
Loading…
x
Reference in New Issue
Block a user