normalization

This commit is contained in:
Roy Han 2024-06-28 17:19:04 -07:00
parent 5213c12354
commit c111d8bb51
4 changed files with 63 additions and 11 deletions

View File

@ -3185,8 +3185,16 @@ int main(int argc, char **argv) {
responses = std::vector<json>(1, result.result_json);
}
json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(json_value(elem, "embedding", json::array()));
if (body["normalize"]) {
for (auto & elem : responses) {
std::vector<float> embedding = elem.at("embedding").get<std::vector<float>>();
embedding = normalize_vector(embedding, embedding.size());
embeddings.push_back(embedding);
}
} else {
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
}
// send the result
json result = json{{"embedding", embeddings}};

View File

@ -656,3 +656,20 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
}
return out;
}
// Scale a vector to unit L2 (Euclidean) length.
//
// vec:  input values; the norm is always computed over every element of vec.
// size: number of elements in the returned vector (normally vec.size()).
//       A negative size is treated as 0, and output positions at or beyond
//       vec.size() are left at 0.0f instead of reading past the end of vec.
//
// Returns a vector of `size` elements holding vec[i] / ||vec||. When the norm
// is zero (empty or all-zero input) every element is 0.0f.
std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
    // Widen to double *before* multiplying: squaring a large float in float
    // precision overflows to inf, which would zero out the whole result.
    double sum_sq = 0.0;
    for (const float value : vec) {
        sum_sq += static_cast<double>(value) * value;
    }
    const double length = std::sqrt(sum_sq);

    // A zero-length vector has no direction; map it to all zeros rather than
    // dividing by zero.
    const float scale = length > 0.0 ? static_cast<float>(1.0 / length) : 0.0f;

    // Guard the int -> size_t conversion so a negative size cannot turn into
    // a gigantic allocation request.
    const std::size_t out_len = size > 0 ? static_cast<std::size_t>(size) : 0;
    std::vector<float> normalized_vec(out_len);  // value-initialized to 0.0f

    // Never index vec beyond its own length, even if the caller asked for a
    // longer output.
    const std::size_t count = out_len < vec.size() ? out_len : vec.size();
    for (std::size_t i = 0; i < count; ++i) {
        normalized_vec[i] = vec[i] * scale;
    }
    return normalized_vec;
}

View File

@ -843,7 +843,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
type EmbedRequest struct {
Content []string `json:"content"`
Content []string `json:"content"`
Normalize bool `json:"normalize"`
}
type EmbedResponse struct {
@ -865,7 +866,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbedRequest{Content: input})
data, err := json.Marshal(EmbedRequest{Content: input, Normalize: true})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
@ -901,11 +902,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
}
type EmbeddingRequest struct {
Content string `json:"content"`
Content string `json:"content"`
Normalize bool `json:"normalize"`
}
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
Embedding [][]float64 `json:"embedding"`
}
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
@ -923,7 +925,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: prompt})
data, err := json.Marshal(EmbeddingRequest{Content: prompt, Normalize: false})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
@ -955,7 +957,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return embedding.Embedding, nil
return embedding.Embedding[0], nil
}
type TokenizeRequest struct {

View File

@ -398,12 +398,22 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
case []string:
case []any:
if reqEmbed == nil {
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
return
}
embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbed)
reqEmbedArray := make([]string, len(reqEmbed))
for i, v := range reqEmbed {
if s, ok := v.(string); ok {
reqEmbedArray[i] = s
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
}
embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
}
@ -414,6 +424,19 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
// assert that embedding is normalized
for _, e := range embeddings {
sum := 0.0
for _, v := range e {
sum += v * v
}
if math.Abs(sum-1) > 1e-6 {
slog.Info("embedding is not normalized", "sum", sum)
} else {
slog.Info("embedding is normalized", "sum", sum)
}
}
resp := api.EmbedResponse{Embeddings: embeddings}
c.JSON(http.StatusOK, resp)
}
@ -486,7 +509,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
for _, v := range embedding {
sum += v * v
}
if math.Abs(sum-1) > 1e-6 {
if math.Abs(sum-1) < 1e-6 {
slog.Info("embedding is normalized", "sum", sum)
} else {
slog.Info("embedding is not normalized", "sum", sum)
}