diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index cf1fe2eb7..801d5b755 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -3185,8 +3185,16 @@ int main(int argc, char **argv) { responses = std::vector(1, result.result_json); } json embeddings = json::array(); - for (auto & elem : responses) { - embeddings.push_back(json_value(elem, "embedding", json::array())); + if (body["normalize"]) { + for (auto & elem : responses) { + std::vector embedding = elem.at("embedding").get>(); + embedding = normalize_vector(embedding, embedding.size()); + embeddings.push_back(embedding); + } + } else { + for (auto & elem : responses) { + embeddings.push_back(elem.at("embedding")); + } } // send the result json result = json{{"embedding", embeddings}}; diff --git a/llm/ext_server/utils.hpp b/llm/ext_server/utils.hpp index d63ead04c..ade49f796 100644 --- a/llm/ext_server/utils.hpp +++ b/llm/ext_server/utils.hpp @@ -656,3 +656,20 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector normalize_vector(const std::vector& vec, int size) { + double sum = 0.0; + for (float value : vec) { + sum += value * value; + } + sum = std::sqrt(sum); + + const float norm = sum > 0.0 ? 1.0f / sum : 0.0f; + + std::vector normalized_vec(size); + for (int i = 0; i < size; i++) { + normalized_vec[i] = vec[i] * norm; + } + return normalized_vec; +} diff --git a/llm/server.go b/llm/server.go index ddd5b66bc..245d054d3 100644 --- a/llm/server.go +++ b/llm/server.go @@ -843,7 +843,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu } type EmbedRequest struct { - Content []string `json:"content"` + Content []string `json:"content"` + Normalize bool `json:"normalize"` } type EmbedResponse struct { @@ -865,7 +866,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) } - data, err := json.Marshal(EmbedRequest{Content: input}) + data, err := json.Marshal(EmbedRequest{Content: input, Normalize: true}) if err != nil { return nil, fmt.Errorf("error marshaling embed data: %w", err) } @@ -901,11 +902,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err } type EmbeddingRequest struct { - Content string `json:"content"` + Content string `json:"content"` + Normalize bool `json:"normalize"` } type EmbeddingResponse struct { - Embedding []float64 `json:"embedding"` + Embedding [][]float64 `json:"embedding"` } func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) { @@ -923,7 +925,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) } - data, err := json.Marshal(TokenizeRequest{Content: prompt}) + data, err := json.Marshal(EmbeddingRequest{Content: prompt, Normalize: false}) if err != nil { return nil, fmt.Errorf("error marshaling embed data: %w", err) } @@ -955,7 +957,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er return nil, fmt.Errorf("unmarshal tokenize response: %w", err) } - return embedding.Embedding, nil + return embedding.Embedding[0], nil } type TokenizeRequest struct { diff --git a/server/routes.go b/server/routes.go index f1735c5f2..e6ff55b01 100644 --- a/server/routes.go +++ b/server/routes.go @@ -398,12 +398,22 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed}) - case []string: + case []any: if reqEmbed == nil { c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}}) return } - embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbed) + + reqEmbedArray := make([]string, len(reqEmbed)) + for i, v := range reqEmbed { + if s, ok := v.(string); ok { + reqEmbedArray[i] = s + } else { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) + return + } + } + embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray) default: c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) } @@ -414,6 +424,19 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } + // assert that embedding is normalized + for _, e := range embeddings { + sum := 0.0 + for _, v := range e { + sum += v * v + } + if math.Abs(sum-1) > 1e-6 { + slog.Info("embedding is not normalized", "sum", sum) + } else { + slog.Info("embedding is normalized", "sum", sum) + } + } + resp := api.EmbedResponse{Embeddings: embeddings} c.JSON(http.StatusOK, resp) } @@ -486,7 +509,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { for _, v := range embedding { sum += v * v } - if math.Abs(sum-1) > 1e-6 { + if math.Abs(sum-1) < 1e-6 { + slog.Info("embedding is normalized", "sum", sum) + } else { slog.Info("embedding is not normalized", "sum", sum) }