move normalization to go

This commit is contained in:
Roy Han 2024-07-01 14:10:58 -07:00
parent 9c32b6b9ed
commit aee25acb5b
7 changed files with 94 additions and 44 deletions

View File

@ -240,6 +240,7 @@ type EmbeddingRequest struct {
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float64 `json:"embeddings,omitempty"`
}

25
format/normalize.go Normal file
View File

@ -0,0 +1,25 @@
package format
import "math"
func Normalize(vec []float64) []float64 {
var sum float64
for _, v := range vec {
sum += v * v
}
sum = math.Sqrt(sum)
var norm float64
if sum > 0 {
norm = 1.0 / sum
} else {
norm = 0.0
}
for i := range vec {
vec[i] *= norm
}
return vec
}

41
format/normalize_test.go Normal file
View File

@ -0,0 +1,41 @@
package format
import (
"math"
"testing"
)
// TestNormalize verifies that Normalize produces unit-length vectors,
// treating an all-zero result as acceptable for the all-zero input.
func TestNormalize(t *testing.T) {
	inputs := [][]float64{
		{1},
		{0, 1, 2, 3},
		{0.1, 0.2, 0.3},
		{-0.1, 0.2, 0.3, -0.4},
		{0, 0, 0},
	}

	// isNormalized reports whether vec has an L2 norm of 1 (within a
	// small tolerance) or is the zero vector.
	isNormalized := func(vec []float64) bool {
		var squaredSum float64
		for _, v := range vec {
			squaredSum += v * v
		}
		return math.Abs(squaredSum-1) <= 1e-6 || squaredSum == 0
	}

	for _, input := range inputs {
		t.Run("", func(t *testing.T) {
			if got := Normalize(input); !isNormalized(got) {
				t.Errorf("Vector %v is not normalized", input)
			}
		})
	}
}

View File

@ -3191,21 +3191,11 @@ int main(int argc, char **argv) {
responses = std::vector<json>(1, result.result_json);
}
json embeddings = json::array();
if (body["normalize"]) {
for (auto & elem : responses) {
std::vector<float> embedding = elem.at("embedding").get<std::vector<float>>();
embedding = normalize_vector(embedding, embedding.size());
embeddings.push_back(embedding);
}
} else {
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json result = json{{"embedding", embeddings}};
// log result
return res.set_content(result.dump(), "application/json; charset=utf-8");
} else {
// return error

View File

@ -657,19 +657,19 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
return out;
}
// normalize a vector
static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
double sum = 0.0;
for (float value : vec) {
sum += value * value;
}
sum = std::sqrt(sum);
// // normalize a vector
// static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
// double sum = 0.0;
// for (float value : vec) {
// sum += value * value;
// }
// sum = std::sqrt(sum);
const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
// const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
std::vector<float> normalized_vec(size);
for (int i = 0; i < size; i++) {
normalized_vec[i] = vec[i] * norm;
}
return normalized_vec;
}
// std::vector<float> normalized_vec(size);
// for (int i = 0; i < size; i++) {
// normalized_vec[i] = vec[i] * norm;
// }
// return normalized_vec;
// }

View File

@ -843,8 +843,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
type EmbedRequest struct {
Content []string `json:"content"`
Normalize bool `json:"normalize"`
Content []string `json:"content"`
}
type EmbedResponse struct {
@ -866,7 +865,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbedRequest{Content: input, Normalize: true})
data, err := json.Marshal(EmbedRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
@ -902,8 +901,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
}
type EmbeddingRequest struct {
Content string `json:"content"`
Normalize bool `json:"normalize"`
Content string `json:"content"`
}
type EmbeddingResponse struct {
@ -925,7 +923,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbeddingRequest{Content: prompt, Normalize: false})
data, err := json.Marshal(EmbeddingRequest{Content: prompt})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}

View File

@ -27,6 +27,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
@ -458,20 +459,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
// assert that embedding is normalized
for _, e := range embeddings {
sum := 0.0
for _, v := range e {
sum += v * v
}
if math.Abs(sum-1) > 1e-6 {
slog.Info("embedding is not normalized", "sum", sum)
} else {
slog.Info("embedding is normalized", "sum", sum)
}
for i, e := range embeddings {
embeddings[i] = format.Normalize(e)
}
resp := api.EmbedResponse{Embeddings: embeddings}
resp := api.EmbedResponse{
Model: req.Model,
Embeddings: embeddings,
}
c.JSON(http.StatusOK, resp)
}