move normalize

This commit is contained in:
Roy Han 2024-07-02 10:35:02 -07:00
parent 00a4cb26ca
commit 3d060e0ae9
3 changed files with 25 additions and 29 deletions

View File

@ -1,25 +0,0 @@
package format
import "math"
func Normalize(vec []float32) []float32 {
var sum float64
for _, v := range vec {
sum += float64(v * v)
}
sum = math.Sqrt(sum)
var norm float32
if sum > 0 {
norm = float32(1.0 / sum)
} else {
norm = 0.0
}
for i := range vec {
vec[i] *= norm
}
return vec
}

View File

@ -1,4 +1,4 @@
package format
package server
import (
"math"
@ -32,7 +32,7 @@ func TestNormalize(t *testing.T) {
for _, tc := range testCases {
t.Run("", func(t *testing.T) {
normalized := Normalize(tc.input)
normalized := normalize(tc.input)
if !assertNorm(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
}

View File

@ -27,7 +27,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
@ -460,7 +459,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
for i, e := range embeddings {
embeddings[i] = format.Normalize(e)
embeddings[i] = normalize(e)
}
resp := api.EmbedResponse{
@ -470,6 +469,28 @@ func (s *Server) EmbedHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func normalize(vec []float32) []float32 {
var sum float64
for _, v := range vec {
sum += float64(v * v)
}
sum = math.Sqrt(sum)
var norm float32
if sum > 0 {
norm = float32(1.0 / sum)
} else {
norm = 0.0
}
for i := range vec {
vec[i] *= norm
}
return vec
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)