From 3d060e0ae9bc3b7623b66e90970c0a6dfff2902c Mon Sep 17 00:00:00 2001 From: Roy Han Date: Tue, 2 Jul 2024 10:35:02 -0700 Subject: [PATCH] move normalize --- format/normalize.go | 25 ------------------------- {format => server}/normalize_test.go | 4 ++-- server/routes.go | 25 +++++++++++++++++++++++-- 3 files changed, 25 insertions(+), 29 deletions(-) delete mode 100644 format/normalize.go rename {format => server}/normalize_test.go (92%) diff --git a/format/normalize.go b/format/normalize.go deleted file mode 100644 index 67a42a933..000000000 --- a/format/normalize.go +++ /dev/null @@ -1,25 +0,0 @@ -package format - -import "math" - -func Normalize(vec []float32) []float32 { - var sum float64 - for _, v := range vec { - sum += float64(v * v) - } - - sum = math.Sqrt(sum) - - var norm float32 - - if sum > 0 { - norm = float32(1.0 / sum) - } else { - norm = 0.0 - } - - for i := range vec { - vec[i] *= norm - } - return vec -} diff --git a/format/normalize_test.go b/server/normalize_test.go similarity index 92% rename from format/normalize_test.go rename to server/normalize_test.go index 69e432ad3..f774e42d0 100644 --- a/format/normalize_test.go +++ b/server/normalize_test.go @@ -1,4 +1,4 @@ -package format +package server import ( "math" @@ -32,7 +32,7 @@ func TestNormalize(t *testing.T) { for _, tc := range testCases { t.Run("", func(t *testing.T) { - normalized := Normalize(tc.input) + normalized := normalize(tc.input) if !assertNorm(normalized) { t.Errorf("Vector %v is not normalized", tc.input) } diff --git a/server/routes.go b/server/routes.go index d493a8982..ee3240709 100644 --- a/server/routes.go +++ b/server/routes.go @@ -27,7 +27,6 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" @@ -460,7 +459,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { } for i, e := range embeddings { - embeddings[i] = format.Normalize(e) + embeddings[i] = normalize(e) } resp := api.EmbedResponse{ @@ -470,6 +469,28 @@ func (s *Server) EmbedHandler(c *gin.Context) { c.JSON(http.StatusOK, resp) } +func normalize(vec []float32) []float32 { + var sum float64 + for _, v := range vec { + sum += float64(v * v) + } + + sum = math.Sqrt(sum) + + var norm float32 + + if sum > 0 { + norm = float32(1.0 / sum) + } else { + norm = 0.0 + } + + for i := range vec { + vec[i] *= norm + } + return vec +} + func (s *Server) EmbeddingsHandler(c *gin.Context) { var req api.EmbeddingRequest err := c.ShouldBindJSON(&req)