From aee25acb5bfdeeec55599c24c7eb51270a2dde33 Mon Sep 17 00:00:00 2001
From: Roy Han
Date: Mon, 1 Jul 2024 14:10:58 -0700
Subject: [PATCH] move normalization to go

---
 api/types.go              |  1 +
 format/normalize.go       | 25 +++++++++++++++++++++++++
 format/normalize_test.go  | 41 +++++++++++++++++++++++++++++++++++++++++
 llm/ext_server/server.cpp | 14 ++------------
 llm/ext_server/utils.hpp  | 28 ++++++++++++++--------------
 llm/server.go             | 10 ++++------
 server/routes.go          | 19 ++++++++-----------
 7 files changed, 94 insertions(+), 44 deletions(-)
 create mode 100644 format/normalize.go
 create mode 100644 format/normalize_test.go

diff --git a/api/types.go b/api/types.go
index 534d9b3e2..44e5c49ae 100644
--- a/api/types.go
+++ b/api/types.go
@@ -240,6 +240,7 @@ type EmbeddingRequest struct {
 
 // EmbedResponse is the response from [Client.Embed].
 type EmbedResponse struct {
+	Model      string      `json:"model"`
 	Embeddings [][]float64 `json:"embeddings,omitempty"`
 }
 
diff --git a/format/normalize.go b/format/normalize.go
new file mode 100644
index 000000000..15aa29b6a
--- /dev/null
+++ b/format/normalize.go
@@ -0,0 +1,25 @@
+package format
+
+import "math"
+
+func Normalize(vec []float64) []float64 {
+	var sum float64
+	for _, v := range vec {
+		sum += v * v
+	}
+
+	sum = math.Sqrt(sum)
+
+	var norm float64
+
+	if sum > 0 {
+		norm = 1.0 / sum
+	} else {
+		norm = 0.0
+	}
+
+	for i := range vec {
+		vec[i] *= norm
+	}
+	return vec
+}
diff --git a/format/normalize_test.go b/format/normalize_test.go
new file mode 100644
index 000000000..fb18a1e6a
--- /dev/null
+++ b/format/normalize_test.go
@@ -0,0 +1,41 @@
+package format
+
+import (
+	"math"
+	"testing"
+)
+
+func TestNormalize(t *testing.T) {
+	type testCase struct {
+		input []float64
+	}
+
+	testCases := []testCase{
+		{input: []float64{1}},
+		{input: []float64{0, 1, 2, 3}},
+		{input: []float64{0.1, 0.2, 0.3}},
+		{input: []float64{-0.1, 0.2, 0.3, -0.4}},
+		{input: []float64{0, 0, 0}},
+	}
+
+	assertNorm := func(vec []float64) (res bool) {
+		sum := 0.0
+		for _, v := range vec {
+			sum += v * v
+		}
+		if math.Abs(sum-1) > 1e-6 {
+			return sum == 0
+		} else {
+			return true
+		}
+	}
+
+	for _, tc := range testCases {
+		t.Run("", func(t *testing.T) {
+			normalized := Normalize(tc.input)
+			if !assertNorm(normalized) {
+				t.Errorf("Vector %v is not normalized", tc.input)
+			}
+		})
+	}
+}
diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp
index b55d5f190..cb0463919 100644
--- a/llm/ext_server/server.cpp
+++ b/llm/ext_server/server.cpp
@@ -3191,21 +3191,11 @@ int main(int argc, char **argv) {
                 responses = std::vector<json>(1, result.result_json);
             }
             json embeddings = json::array();
-            if (body["normalize"]) {
-                for (auto & elem : responses) {
-                    std::vector<float> embedding = elem.at("embedding").get<std::vector<float>>();
-                    embedding = normalize_vector(embedding, embedding.size());
-                    embeddings.push_back(embedding);
-                }
-            } else {
-                for (auto & elem : responses) {
-                    embeddings.push_back(elem.at("embedding"));
-                }
+            for (auto & elem : responses) {
+                embeddings.push_back(elem.at("embedding"));
             }
             // send the result
             json result = json{{"embedding", embeddings}};
-            // log result
-
             return res.set_content(result.dump(), "application/json; charset=utf-8");
         } else {
             // return error
diff --git a/llm/ext_server/utils.hpp b/llm/ext_server/utils.hpp
index ee63cf786..3e54d0c0e 100644
--- a/llm/ext_server/utils.hpp
+++ b/llm/ext_server/utils.hpp
@@ -657,19 +657,19 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<completion_token_output> &probs)
-// normalize a vector
-static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
-    double sum = 0.0;
-    for (float value : vec) {
-        sum += value * value;
-    }
-    sum = std::sqrt(sum);
+// // normalize a vector
+// static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
+//     double sum = 0.0;
+//     for (float value : vec) {
+//         sum += value * value;
+//     }
+//     sum = std::sqrt(sum);
 
-    const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
+//     const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
 
-    std::vector<float> normalized_vec(size);
-    for (int i = 0; i < size; i++) {
-        normalized_vec[i] = vec[i] * norm;
-    }
-    return normalized_vec;
-}
+//     std::vector<float> normalized_vec(size);
+//     for (int i = 0; i < size; i++) {
+//         normalized_vec[i] = vec[i] * norm;
+//     }
+//     return normalized_vec;
+// }
diff --git a/llm/server.go b/llm/server.go
index 245d054d3..1d4ca460f 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -843,8 +843,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 }
 
 type EmbedRequest struct {
-	Content   []string `json:"content"`
-	Normalize bool     `json:"normalize"`
+	Content []string `json:"content"`
 }
 
 type EmbedResponse struct {
@@ -866,7 +865,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(EmbedRequest{Content: input, Normalize: true})
+	data, err := json.Marshal(EmbedRequest{Content: input})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
@@ -902,8 +901,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
 }
 
 type EmbeddingRequest struct {
-	Content   string `json:"content"`
-	Normalize bool   `json:"normalize"`
+	Content string `json:"content"`
 }
 
 type EmbeddingResponse struct {
@@ -925,7 +923,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(EmbeddingRequest{Content: prompt, Normalize: false})
+	data, err := json.Marshal(EmbeddingRequest{Content: prompt})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
diff --git a/server/routes.go b/server/routes.go
index b305ddc76..ba2ac9eeb 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -27,6 +27,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
@@ -458,20 +459,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	// assert that embedding is normalized
-	for _, e := range embeddings {
-		sum := 0.0
-		for _, v := range e {
-			sum += v * v
-		}
-		if math.Abs(sum-1) > 1e-6 {
-			slog.Info("embedding is not normalized", "sum", sum)
-		} else {
-			slog.Info("embedding is normalized", "sum", sum)
-		}
+	for i, e := range embeddings {
+		embeddings[i] = format.Normalize(e)
 	}
 
-	resp := api.EmbedResponse{Embeddings: embeddings}
+	resp := api.EmbedResponse{
+		Model:      req.Model,
+		Embeddings: embeddings,
+	}
 
 	c.JSON(http.StatusOK, resp)
 }
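
Usage note: format.Normalize, introduced above, scales a []float64 to unit
length in place and returns the same slice; a zero vector is left untouched
rather than dividing by zero. A minimal sketch of calling it directly, using
only the package added by this patch:

    package main

    import (
        "fmt"

        "github.com/ollama/ollama/format"
    )

    func main() {
        vec := []float64{3, 4}        // magnitude 5
        unit := format.Normalize(vec) // normalizes in place and returns vec
        fmt.Println(unit)             // [0.6 0.8]

        zero := format.Normalize([]float64{0, 0, 0})
        fmt.Println(zero)             // zero vectors stay [0 0 0]
    }

With this change EmbedHandler always normalizes embeddings server-side before
building the response, so the "normalize" request flag and the C++
normalize_vector helper are no longer needed.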