refactoring
This commit is contained in:
parent
c697eb2a9b
commit
8f6d0242b6
21
llm/ext_server/server.cpp
vendored
21
llm/ext_server/server.cpp
vendored
@ -3202,19 +3202,18 @@ int main(int argc, char **argv) {
|
|||||||
// get the result
|
// get the result
|
||||||
task_result result = llama.queue_results.recv(id_task);
|
task_result result = llama.queue_results.recv(id_task);
|
||||||
llama.queue_results.remove_waiting_task_id(id_task);
|
llama.queue_results.remove_waiting_task_id(id_task);
|
||||||
if (!result.error) {
|
if (result.error) {
|
||||||
responses = result.result_json.value("results", std::vector<json>{result.result_json});
|
|
||||||
json embeddings = json::array();
|
|
||||||
for (auto & elem : responses) {
|
|
||||||
embeddings.push_back(elem.at("embedding"));
|
|
||||||
}
|
|
||||||
// send the result
|
|
||||||
json result = json{{"embedding", embeddings}};
|
|
||||||
return res.set_content(result.dump(), "application/json; charset=utf-8");
|
|
||||||
} else {
|
|
||||||
// return error
|
|
||||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
responses = result.result_json.value("results", std::vector<json>{result.result_json});
|
||||||
|
json embeddings = json::array();
|
||||||
|
for (auto & elem : responses) {
|
||||||
|
embeddings.push_back(elem.at("embedding"));
|
||||||
|
}
|
||||||
|
// send the result
|
||||||
|
json embedding_res = json{{"embedding", embeddings}};
|
||||||
|
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -21,7 +22,6 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/chewxy/math32"
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
@ -287,23 +287,27 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch reqEmbed := req.Input.(type) {
|
reqEmbed := []string{}
|
||||||
|
|
||||||
|
switch embeddings := req.Input.(type) {
|
||||||
case string:
|
case string:
|
||||||
if reqEmbed == "" {
|
if embeddings == "" {
|
||||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
reqEmbed = []string{embeddings}
|
||||||
case []any:
|
case []any:
|
||||||
if reqEmbed == nil {
|
if len(embeddings) == 0 {
|
||||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range reqEmbed {
|
for _, v := range embeddings {
|
||||||
if _, ok := v.(string); !ok {
|
if _, ok := v.(string); !ok {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
reqEmbed = append(reqEmbed, v.(string))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||||
@ -335,30 +339,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddings := [][]float32{}
|
reqEmbedArray := make([]string, len(reqEmbed))
|
||||||
|
for i, v := range reqEmbed {
|
||||||
switch reqEmbed := req.Input.(type) {
|
s, err := checkFit(v, *req.Truncate)
|
||||||
case string:
|
|
||||||
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
embeddings, err = r.Embed(c.Request.Context(), []string{reqEmbed})
|
reqEmbedArray[i] = s
|
||||||
case []any:
|
|
||||||
reqEmbedArray := make([]string, len(reqEmbed))
|
|
||||||
for i, v := range reqEmbed {
|
|
||||||
s, err := checkFit(v.(string), *req.Truncate)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
reqEmbedArray[i] = s
|
|
||||||
}
|
|
||||||
embeddings, err = r.Embed(c.Request.Context(), reqEmbedArray)
|
|
||||||
default:
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
||||||
}
|
}
|
||||||
|
embeddings, err := r.Embed(c.Request.Context(), reqEmbedArray)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||||
@ -385,7 +375,7 @@ func normalize(vec []float32) []float32 {
|
|||||||
|
|
||||||
norm := float32(0.0)
|
norm := float32(0.0)
|
||||||
if sum > 0 {
|
if sum > 0 {
|
||||||
norm = float32(1.0 / math32.Sqrt(sum))
|
norm = float32(1.0 / math.Sqrt(float64(sum)))
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range vec {
|
for i := range vec {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user