move normalize

This commit is contained in:
Roy Han 2024-07-02 10:35:02 -07:00
parent 00a4cb26ca
commit 3d060e0ae9
3 changed files with 25 additions and 29 deletions

View File

@ -1,25 +0,0 @@
package format
import "math"
func Normalize(vec []float32) []float32 {
var sum float64
for _, v := range vec {
sum += float64(v * v)
}
sum = math.Sqrt(sum)
var norm float32
if sum > 0 {
norm = float32(1.0 / sum)
} else {
norm = 0.0
}
for i := range vec {
vec[i] *= norm
}
return vec
}

View File

@ -1,4 +1,4 @@
package format package server
import ( import (
"math" "math"
@ -32,7 +32,7 @@ func TestNormalize(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
normalized := Normalize(tc.input) normalized := normalize(tc.input)
if !assertNorm(normalized) { if !assertNorm(normalized) {
t.Errorf("Vector %v is not normalized", tc.input) t.Errorf("Vector %v is not normalized", tc.input)
} }

View File

@ -27,7 +27,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
@ -460,7 +459,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
for i, e := range embeddings { for i, e := range embeddings {
embeddings[i] = format.Normalize(e) embeddings[i] = normalize(e)
} }
resp := api.EmbedResponse{ resp := api.EmbedResponse{
@ -470,6 +469,28 @@ func (s *Server) EmbedHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
func normalize(vec []float32) []float32 {
var sum float64
for _, v := range vec {
sum += float64(v * v)
}
sum = math.Sqrt(sum)
var norm float32
if sum > 0 {
norm = float32(1.0 / sum)
} else {
norm = 0.0
}
for i := range vec {
vec[i] *= norm
}
return vec
}
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)