move normalization to go

parent 9c32b6b9ed
commit aee25acb5b
@@ -240,6 +240,7 @@ type EmbeddingRequest struct {
 
 // EmbedResponse is the response from [Client.Embed].
 type EmbedResponse struct {
+	Model      string      `json:"model"`
 	Embeddings [][]float64 `json:"embeddings,omitempty"`
 }
 
format/normalize.go (new file, 25 lines)
@@ -0,0 +1,25 @@
+package format
+
+import "math"
+
+func Normalize(vec []float64) []float64 {
+	var sum float64
+	for _, v := range vec {
+		sum += v * v
+	}
+
+	sum = math.Sqrt(sum)
+
+	var norm float64
+
+	if sum > 0 {
+		norm = 1.0 / sum
+	} else {
+		norm = 0.0
+	}
+
+	for i := range vec {
+		vec[i] *= norm
+	}
+	return vec
+}
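Not part of the diff, but for context: a minimal usage sketch of the new format.Normalize helper. It assumes only what the file above shows — non-zero vectors are scaled in place to unit L2 norm, and an all-zero vector comes back unchanged.

// Usage sketch (hypothetical caller, not part of the commit).
package main

import (
	"fmt"
	"math"

	"github.com/ollama/ollama/format"
)

func main() {
	v := format.Normalize([]float64{3, 4}) // -> [0.6, 0.8]

	// Verify the result has unit L2 norm, mirroring the test below.
	var sum float64
	for _, x := range v {
		sum += x * x
	}
	fmt.Println(math.Abs(sum-1) < 1e-6) // true
}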
format/normalize_test.go (new file, 41 lines)
@@ -0,0 +1,41 @@
+package format
+
+import (
+	"math"
+	"testing"
+)
+
+func TestNormalize(t *testing.T) {
+	type testCase struct {
+		input []float64
+	}
+
+	testCases := []testCase{
+		{input: []float64{1}},
+		{input: []float64{0, 1, 2, 3}},
+		{input: []float64{0.1, 0.2, 0.3}},
+		{input: []float64{-0.1, 0.2, 0.3, -0.4}},
+		{input: []float64{0, 0, 0}},
+	}
+
+	assertNorm := func(vec []float64) (res bool) {
+		sum := 0.0
+		for _, v := range vec {
+			sum += v * v
+		}
+		if math.Abs(sum-1) > 1e-6 {
+			return sum == 0
+		} else {
+			return true
+		}
+	}
+
+	for _, tc := range testCases {
+		t.Run("", func(t *testing.T) {
+			normalized := Normalize(tc.input)
+			if !assertNorm(normalized) {
+				t.Errorf("Vector %v is not normalized", tc.input)
+			}
+		})
+	}
+}
llm/ext_server/server.cpp (vendored, 14 lines changed)
@@ -3191,21 +3191,11 @@ int main(int argc, char **argv) {
                 responses = std::vector<json>(1, result.result_json);
             }
             json embeddings = json::array();
-            if (body["normalize"]) {
-                for (auto & elem : responses) {
-                    std::vector<float> embedding = elem.at("embedding").get<std::vector<float>>();
-                    embedding = normalize_vector(embedding, embedding.size());
-                    embeddings.push_back(embedding);
-                }
-            } else {
-                for (auto & elem : responses) {
-                    embeddings.push_back(elem.at("embedding"));
-                }
-            }
+            for (auto & elem : responses) {
+                embeddings.push_back(elem.at("embedding"));
+            }
             // send the result
             json result = json{{"embedding", embeddings}};
-            // log result
-
             return res.set_content(result.dump(), "application/json; charset=utf-8");
         } else {
             // return error
llm/ext_server/utils.hpp (vendored, 28 lines changed)
@@ -657,19 +657,19 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
     return out;
 }
 
-// normalize a vector
-static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
-    double sum = 0.0;
-    for (float value : vec) {
-        sum += value * value;
-    }
-    sum = std::sqrt(sum);
+// // normalize a vector
+// static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
+//     double sum = 0.0;
+//     for (float value : vec) {
+//         sum += value * value;
+//     }
+//     sum = std::sqrt(sum);
 
-    const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
+//     const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
 
-    std::vector<float> normalized_vec(size);
-    for (int i = 0; i < size; i++) {
-        normalized_vec[i] = vec[i] * norm;
-    }
-    return normalized_vec;
-}
+//     std::vector<float> normalized_vec(size);
+//     for (int i = 0; i < size; i++) {
+//         normalized_vec[i] = vec[i] * norm;
+//     }
+//     return normalized_vec;
+// }
@@ -843,8 +843,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 }
 
 type EmbedRequest struct {
-	Content   []string `json:"content"`
-	Normalize bool     `json:"normalize"`
+	Content []string `json:"content"`
 }
 
 type EmbedResponse struct {
@@ -866,7 +865,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(EmbedRequest{Content: input, Normalize: true})
+	data, err := json.Marshal(EmbedRequest{Content: input})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
@@ -902,8 +901,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
 }
 
 type EmbeddingRequest struct {
-	Content   string `json:"content"`
-	Normalize bool   `json:"normalize"`
+	Content string `json:"content"`
 }
 
 type EmbeddingResponse struct {
@@ -925,7 +923,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(EmbeddingRequest{Content: prompt, Normalize: false})
+	data, err := json.Marshal(EmbeddingRequest{Content: prompt})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
@@ -27,6 +27,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
@@ -458,20 +459,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	// assert that embedding is normalized
-	for _, e := range embeddings {
-		sum := 0.0
-		for _, v := range e {
-			sum += v * v
-		}
-		if math.Abs(sum-1) > 1e-6 {
-			slog.Info("embedding is not normalized", "sum", sum)
-		} else {
-			slog.Info("embedding is normalized", "sum", sum)
-		}
+	for i, e := range embeddings {
+		embeddings[i] = format.Normalize(e)
 	}
 
-	resp := api.EmbedResponse{Embeddings: embeddings}
+	resp := api.EmbedResponse{
+		Model:      req.Model,
+		Embeddings: embeddings,
+	}
 	c.JSON(http.StatusOK, resp)
 }
 
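A closing note, not part of the commit: because EmbedHandler now passes every embedding through format.Normalize before building the response, vectors returned by the embed endpoint have unit L2 norm (all-zero vectors stay zero). A minimal sketch of what that buys a caller, using a hypothetical cosineSimilarity helper that is not in the ollama codebase:

package main

import "fmt"

// cosineSimilarity: for unit-length vectors, cosine similarity reduces to a
// plain dot product, so no extra client-side normalization is needed.
func cosineSimilarity(a, b []float64) float64 {
	var dot float64
	for i := range a {
		dot += a[i] * b[i]
	}
	return dot
}

func main() {
	// Two already-normalized vectors, as the embed endpoint now returns.
	a := []float64{0.6, 0.8}
	b := []float64{0.8, 0.6}
	fmt.Println(cosineSimilarity(a, b)) // 0.96
}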