normalization

parent 5213c12354
commit c111d8bb51

llm/ext_server/server.cpp (vendored): 12 changed lines

@@ -3185,8 +3185,16 @@ int main(int argc, char **argv) {
                 responses = std::vector<json>(1, result.result_json);
             }
             json embeddings = json::array();
-            for (auto & elem : responses) {
-                embeddings.push_back(json_value(elem, "embedding", json::array()));
+            if (body["normalize"]) {
+                for (auto & elem : responses) {
+                    std::vector<float> embedding = elem.at("embedding").get<std::vector<float>>();
+                    embedding = normalize_vector(embedding, embedding.size());
+                    embeddings.push_back(embedding);
+                }
+            } else {
+                for (auto & elem : responses) {
+                    embeddings.push_back(elem.at("embedding"));
+                }
             }
             // send the result
             json result = json{{"embedding", embeddings}};

llm/ext_server/utils.hpp (vendored): 17 changed lines

@@ -656,3 +656,20 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
     }
     return out;
 }
+
+// normalize a vector
+std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
+    double sum = 0.0;
+    for (float value : vec) {
+        sum += value * value;
+    }
+    sum = std::sqrt(sum);
+
+    const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
+
+    std::vector<float> normalized_vec(size);
+    for (int i = 0; i < size; i++) {
+        normalized_vec[i] = vec[i] * norm;
+    }
+    return normalized_vec;
+}
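
normalize_vector above scales a vector to unit L2 norm; a zero vector stays all zeros because the reciprocal norm falls back to 0.0f. A rough Go sketch of the same math, useful only as an illustration of the expected result on the client side (not code from this commit):

package main

import (
	"fmt"
	"math"
)

// normalizeVector scales vec to unit L2 norm; a zero vector stays all zeros.
func normalizeVector(vec []float64) []float64 {
	sum := 0.0
	for _, v := range vec {
		sum += v * v
	}
	norm := 0.0
	if sum > 0 {
		norm = 1.0 / math.Sqrt(sum)
	}
	out := make([]float64, len(vec))
	for i, v := range vec {
		out[i] = v * norm
	}
	return out
}

func main() {
	// [3, 4] has L2 norm 5, so it normalizes to [0.6, 0.8].
	fmt.Println(normalizeVector([]float64{3, 4}))
}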

@@ -843,7 +843,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 }
 
 type EmbedRequest struct {
-    Content []string `json:"content"`
+    Content   []string `json:"content"`
+    Normalize bool     `json:"normalize"`
 }
 
 type EmbedResponse struct {

@@ -865,7 +866,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
         return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
     }
 
-    data, err := json.Marshal(EmbedRequest{Content: input})
+    data, err := json.Marshal(EmbedRequest{Content: input, Normalize: true})
     if err != nil {
         return nil, fmt.Errorf("error marshaling embed data: %w", err)
     }
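
Embed now always sets Normalize: true, so the body posted to the ext_server carries both fields. A minimal sketch of the resulting wire format, using a placeholder prompt (illustration only, not code from the commit):

package main

import (
	"encoding/json"
	"fmt"
)

type EmbedRequest struct {
	Content   []string `json:"content"`
	Normalize bool     `json:"normalize"`
}

func main() {
	// Placeholder prompt, only to show the marshaled shape.
	data, _ := json.Marshal(EmbedRequest{
		Content:   []string{"why is the sky blue?"},
		Normalize: true,
	})
	// Prints: {"content":["why is the sky blue?"],"normalize":true}
	fmt.Println(string(data))
}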

@@ -901,11 +902,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, err
 }
 
 type EmbeddingRequest struct {
-    Content string `json:"content"`
+    Content   string `json:"content"`
+    Normalize bool   `json:"normalize"`
 }
 
 type EmbeddingResponse struct {
-    Embedding []float64 `json:"embedding"`
+    Embedding [][]float64 `json:"embedding"`
 }
 
 func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {

@@ -923,7 +925,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
         return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
     }
 
-    data, err := json.Marshal(TokenizeRequest{Content: prompt})
+    data, err := json.Marshal(EmbeddingRequest{Content: prompt, Normalize: false})
     if err != nil {
         return nil, fmt.Errorf("error marshaling embed data: %w", err)
     }

@@ -955,7 +957,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
         return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
     }
 
-    return embedding.Embedding, nil
+    return embedding.Embedding[0], nil
 }
 
 type TokenizeRequest struct {
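
Since the ext_server now returns embeddings as an array even for a single prompt, EmbeddingResponse becomes [][]float64 and the single-prompt path unwraps row 0. A small sketch of that unwrap with a hypothetical empty-response guard (the commit itself indexes row 0 directly):

package main

import (
	"errors"
	"fmt"
)

// firstEmbedding returns the first row of a batch-shaped embedding response.
// The length guard is a hypothetical addition, not part of the commit.
func firstEmbedding(resp [][]float64) ([]float64, error) {
	if len(resp) == 0 {
		return nil, errors.New("empty embedding response")
	}
	return resp[0], nil
}

func main() {
	row, err := firstEmbedding([][]float64{{0.6, 0.8}})
	fmt.Println(row, err) // [0.6 0.8] <nil>
}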

@@ -398,12 +398,22 @@ func (s *Server) EmbedHandler(c *gin.Context) {
             return
         }
         embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
-    case []string:
+    case []any:
         if reqEmbed == nil {
             c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
             return
         }
-        embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbed)
+
+        reqEmbedArray := make([]string, len(reqEmbed))
+        for i, v := range reqEmbed {
+            if s, ok := v.(string); ok {
+                reqEmbedArray[i] = s
+            } else {
+                c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+                return
+            }
+        }
+        embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
     default:
         c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
     }
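
The switch arm moves from []string to []any because encoding/json decodes a JSON array into []interface{} when the destination is an untyped interface, so a []string case never matches; hence the element-by-element conversion above. A standalone sketch of that decoding behavior (assumed setup, not code from the commit):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	var input any
	_ = json.Unmarshal([]byte(`["a", "b"]`), &input)

	switch v := input.(type) {
	case []string:
		fmt.Println("matched []string:", v) // never reached
	case []any:
		// encoding/json always yields []any for JSON arrays here.
		out := make([]string, len(v))
		for i, e := range v {
			s, ok := e.(string)
			if !ok {
				fmt.Println("invalid input type")
				return
			}
			out[i] = s
		}
		fmt.Println("converted:", out) // converted: [a b]
	}
}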

@@ -414,6 +424,19 @@ func (s *Server) EmbedHandler(c *gin.Context) {
         return
     }
 
+    // assert that embedding is normalized
+    for _, e := range embeddings {
+        sum := 0.0
+        for _, v := range e {
+            sum += v * v
+        }
+        if math.Abs(sum-1) > 1e-6 {
+            slog.Info("embedding is not normalized", "sum", sum)
+        } else {
+            slog.Info("embedding is normalized", "sum", sum)
+        }
+    }
+
     resp := api.EmbedResponse{Embeddings: embeddings}
     c.JSON(http.StatusOK, resp)
 }
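
The added loop only logs whether each embedding has unit L2 norm, i.e. whether its sum of squares is within 1e-6 of 1. The same check pulled into a small helper for illustration, with a hypothetical name not used in the commit:

package main

import (
	"fmt"
	"math"
)

// isNormalized reports whether e has unit L2 norm within tol.
func isNormalized(e []float64, tol float64) bool {
	sum := 0.0
	for _, v := range e {
		sum += v * v
	}
	return math.Abs(sum-1) <= tol
}

func main() {
	fmt.Println(isNormalized([]float64{0.6, 0.8}, 1e-6)) // true
	fmt.Println(isNormalized([]float64{3, 4}, 1e-6))     // false
}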

@@ -486,7 +509,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
     for _, v := range embedding {
         sum += v * v
     }
-    if math.Abs(sum-1) > 1e-6 {
+    if math.Abs(sum-1) < 1e-6 {
+        slog.Info("embedding is normalized", "sum", sum)
+    } else {
         slog.Info("embedding is not normalized", "sum", sum)
     }