refactoring

This commit is contained in:
Roy Han 2024-07-09 16:19:02 -07:00
parent c697eb2a9b
commit 8f6d0242b6
2 changed files with 25 additions and 36 deletions

View File

@ -3202,19 +3202,18 @@ int main(int argc, char **argv) {
// get the result // get the result
task_result result = llama.queue_results.recv(id_task); task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(id_task); llama.queue_results.remove_waiting_task_id(id_task);
if (!result.error) { if (result.error) {
responses = result.result_json.value("results", std::vector<json>{result.result_json});
json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json result = json{{"embedding", embeddings}};
return res.set_content(result.dump(), "application/json; charset=utf-8");
} else {
// return error
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
} }
responses = result.result_json.value("results", std::vector<json>{result.result_json});
json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json embedding_res = json{{"embedding", embeddings}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
} }
}); });

View File

@ -10,6 +10,7 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@ -21,7 +22,6 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/chewxy/math32"
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -287,23 +287,27 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
switch reqEmbed := req.Input.(type) { reqEmbed := []string{}
switch embeddings := req.Input.(type) {
case string: case string:
if reqEmbed == "" { if embeddings == "" {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return return
} }
reqEmbed = []string{embeddings}
case []any: case []any:
if reqEmbed == nil { if len(embeddings) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return return
} }
for _, v := range reqEmbed { for _, v := range embeddings {
if _, ok := v.(string); !ok { if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return return
} }
reqEmbed = append(reqEmbed, v.(string))
} }
default: default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
@ -335,30 +339,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return s, nil return s, nil
} }
embeddings := [][]float32{} reqEmbedArray := make([]string, len(reqEmbed))
for i, v := range reqEmbed {
switch reqEmbed := req.Input.(type) { s, err := checkFit(v, *req.Truncate)
case string:
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
embeddings, err = r.Embed(c.Request.Context(), []string{reqEmbed}) reqEmbedArray[i] = s
case []any:
reqEmbedArray := make([]string, len(reqEmbed))
for i, v := range reqEmbed {
s, err := checkFit(v.(string), *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
reqEmbedArray[i] = s
}
embeddings, err = r.Embed(c.Request.Context(), reqEmbedArray)
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
} }
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
@ -385,7 +375,7 @@ func normalize(vec []float32) []float32 {
norm := float32(0.0) norm := float32(0.0)
if sum > 0 { if sum > 0 {
norm = float32(1.0 / math32.Sqrt(sum)) norm = float32(1.0 / math.Sqrt(float64(sum)))
} }
for i := range vec { for i := range vec {