diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp
index 0eb69538c..e8a076c43 100644
--- a/llm/ext_server/server.cpp
+++ b/llm/ext_server/server.cpp
@@ -3202,19 +3202,18 @@ int main(int argc, char **argv) {
             // get the result
             task_result result = llama.queue_results.recv(id_task);
             llama.queue_results.remove_waiting_task_id(id_task);
-            if (!result.error) {
-                responses = result.result_json.value("results", std::vector<json>{result.result_json});
-                json embeddings = json::array();
-                for (auto & elem : responses) {
-                    embeddings.push_back(elem.at("embedding"));
-                }
-                // send the result
-                json result = json{{"embedding", embeddings}};
-                return res.set_content(result.dump(), "application/json; charset=utf-8");
-            } else {
-                // return error
+            if (result.error) {
                 return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
             }
+
+            responses = result.result_json.value("results", std::vector<json>{result.result_json});
+            json embeddings = json::array();
+            for (auto & elem : responses) {
+                embeddings.push_back(elem.at("embedding"));
+            }
+            // send the result
+            json embedding_res = json{{"embedding", embeddings}};
+            return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
         }
     });
 
diff --git a/server/routes.go b/server/routes.go
index a975e94fe..782028c0a 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -10,6 +10,7 @@ import (
     "io"
     "io/fs"
     "log/slog"
+    "math"
     "net"
     "net/http"
     "net/netip"
@@ -21,7 +22,6 @@ import (
     "syscall"
     "time"
 
-    "github.com/chewxy/math32"
     "github.com/gin-contrib/cors"
     "github.com/gin-gonic/gin"
 
@@ -287,23 +287,27 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         return
     }
-    switch reqEmbed := req.Input.(type) {
+    reqEmbed := []string{}
+
+    switch embeddings := req.Input.(type) {
     case string:
-        if reqEmbed == "" {
+        if embeddings == "" {
             c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
             return
         }
+        reqEmbed = []string{embeddings}
     case []any:
-        if reqEmbed == nil {
+        if len(embeddings) == 0 {
             c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
             return
         }
-        for _, v := range reqEmbed {
+        for _, v := range embeddings {
             if _, ok := v.(string); !ok {
                 c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
                 return
             }
+            reqEmbed = append(reqEmbed, v.(string))
         }
     default:
         c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
         return
     }
@@ -335,30 +339,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         return s, nil
     }
 
-    embeddings := [][]float32{}
-
-    switch reqEmbed := req.Input.(type) {
-    case string:
-        reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
+    reqEmbedArray := make([]string, len(reqEmbed))
+    for i, v := range reqEmbed {
+        s, err := checkFit(v, *req.Truncate)
         if err != nil {
             c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
             return
         }
-        embeddings, err = r.Embed(c.Request.Context(), []string{reqEmbed})
-    case []any:
-        reqEmbedArray := make([]string, len(reqEmbed))
-        for i, v := range reqEmbed {
-            s, err := checkFit(v.(string), *req.Truncate)
-            if err != nil {
-                c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-                return
-            }
-            reqEmbedArray[i] = s
-        }
-        embeddings, err = r.Embed(c.Request.Context(), reqEmbedArray)
-    default:
-        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+        reqEmbedArray[i] = s
     }
 
+    embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
     if err != nil {
         slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
@@ -385,7 +375,7 @@ func normalize(vec []float32) []float32 {
 
     norm := float32(0.0)
     if sum > 0 {
-        norm = float32(1.0 / math32.Sqrt(sum))
+        norm = float32(1.0 / math.Sqrt(float64(sum)))
     }
 
     for i := range vec {
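For context on the routes.go hunks: EmbedHandler now flattens the request input into a single []string up front and embeds everything in one batch, instead of type-switching on req.Input twice. Below is a minimal, hypothetical sketch of that coercion step in isolation; the name coerceInput is illustrative, and plain errors stand in for the gin JSON responses used in the handler.

```go
package main

import (
	"errors"
	"fmt"
)

// coerceInput mirrors the input handling the diff adds to EmbedHandler: a bare
// string or a []any of strings is flattened into one []string before any
// embedding work happens. An empty string or empty slice yields an empty
// result, mirroring the handler's early-return paths.
func coerceInput(input any) ([]string, error) {
	switch v := input.(type) {
	case string:
		if v == "" {
			return []string{}, nil
		}
		return []string{v}, nil
	case []any:
		if len(v) == 0 {
			return []string{}, nil
		}
		out := make([]string, 0, len(v))
		for _, e := range v {
			s, ok := e.(string)
			if !ok {
				return nil, errors.New("invalid input type")
			}
			out = append(out, s)
		}
		return out, nil
	default:
		return nil, errors.New("invalid input type")
	}
}

func main() {
	fmt.Println(coerceInput("hello"))         // [hello] <nil>
	fmt.Println(coerceInput([]any{"a", "b"})) // [a b] <nil>
	fmt.Println(coerceInput(42))              // [] invalid input type
}
```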
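The last hunk swaps github.com/chewxy/math32 for the standard library inside normalize: math.Sqrt operates on float64, so the sum is converted up and the result cast back to float32. Here is a sketch of how the whole function plausibly reads after the change, assuming the usual sum-of-squares and scaling loops around the one line shown in the hunk.

```go
package main

import (
	"fmt"
	"math"
)

// normalize scales a vector to unit length. Only the sqrt line appears in the
// hunk above; the surrounding sum and scaling loops are assumptions about the
// rest of the function.
func normalize(vec []float32) []float32 {
	var sum float32
	for _, v := range vec {
		sum += v * v
	}

	norm := float32(0.0)
	if sum > 0 {
		// math.Sqrt takes a float64, so convert the float32 sum up and the
		// reciprocal back down.
		norm = float32(1.0 / math.Sqrt(float64(sum)))
	}

	for i := range vec {
		vec[i] *= norm
	}
	return vec
}

func main() {
	fmt.Println(normalize([]float32{3, 4})) // [0.6 0.8]
}
```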