move normalize
This commit is contained in:
parent
00a4cb26ca
commit
3d060e0ae9
@ -1,25 +0,0 @@
|
|||||||
package format
|
|
||||||
|
|
||||||
import "math"
|
|
||||||
|
|
||||||
func Normalize(vec []float32) []float32 {
|
|
||||||
var sum float64
|
|
||||||
for _, v := range vec {
|
|
||||||
sum += float64(v * v)
|
|
||||||
}
|
|
||||||
|
|
||||||
sum = math.Sqrt(sum)
|
|
||||||
|
|
||||||
var norm float32
|
|
||||||
|
|
||||||
if sum > 0 {
|
|
||||||
norm = float32(1.0 / sum)
|
|
||||||
} else {
|
|
||||||
norm = 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range vec {
|
|
||||||
vec[i] *= norm
|
|
||||||
}
|
|
||||||
return vec
|
|
||||||
}
|
|
@ -1,4 +1,4 @@
|
|||||||
package format
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
@ -32,7 +32,7 @@ func TestNormalize(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
normalized := Normalize(tc.input)
|
normalized := normalize(tc.input)
|
||||||
if !assertNorm(normalized) {
|
if !assertNorm(normalized) {
|
||||||
t.Errorf("Vector %v is not normalized", tc.input)
|
t.Errorf("Vector %v is not normalized", tc.input)
|
||||||
}
|
}
|
@ -27,7 +27,6 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
@ -460,7 +459,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, e := range embeddings {
|
for i, e := range embeddings {
|
||||||
embeddings[i] = format.Normalize(e)
|
embeddings[i] = normalize(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := api.EmbedResponse{
|
resp := api.EmbedResponse{
|
||||||
@ -470,6 +469,28 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalize(vec []float32) []float32 {
|
||||||
|
var sum float64
|
||||||
|
for _, v := range vec {
|
||||||
|
sum += float64(v * v)
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = math.Sqrt(sum)
|
||||||
|
|
||||||
|
var norm float32
|
||||||
|
|
||||||
|
if sum > 0 {
|
||||||
|
norm = float32(1.0 / sum)
|
||||||
|
} else {
|
||||||
|
norm = 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range vec {
|
||||||
|
vec[i] *= norm
|
||||||
|
}
|
||||||
|
return vec
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
var req api.EmbeddingRequest
|
var req api.EmbeddingRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user