Introduce /api/embed endpoint supporting batch embedding (#5127)
* Initial Batch Embedding * Revert "Initial Batch Embedding" This reverts commitc22d54895a. * Initial Draft * mock up notes * api/embed draft * add server function * check normalization * clean up * normalization * playing around with truncate stuff * Truncation * Truncation * move normalization to go * Integration Test Template * Truncation Integration Tests * Clean up * use float32 * move normalize * move normalize test * refactoring * integration float32 * input handling and handler testing * Refactoring of legacy and new * clear comments * merge conflicts * touches * embedding type 64 * merge conflicts * fix hanging on single string * refactoring * test values * set context length * clean up * testing clean up * testing clean up * remove function closure * Revert "remove function closure" This reverts commit55d48c6ed1. * remove function closure * remove redundant error check * clean up * more clean up * clean up
This commit is contained in:
152
integration/embed_test.go
Normal file
152
integration/embed_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestAllMiniLMEmbed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 1 {
|
||||
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
|
||||
}
|
||||
|
||||
if len(res.Embeddings[0]) != 384 {
|
||||
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||
}
|
||||
|
||||
if res.Embeddings[0][0] != 0.010071031 {
|
||||
t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 2 {
|
||||
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
|
||||
}
|
||||
|
||||
if len(res.Embeddings[0]) != 384 {
|
||||
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||
}
|
||||
|
||||
if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
|
||||
t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
truncTrue, truncFalse := true, false
|
||||
|
||||
type testReq struct {
|
||||
Name string
|
||||
Request api.EmbedRequest
|
||||
}
|
||||
|
||||
reqs := []testReq{
|
||||
{
|
||||
Name: "Target Truncation",
|
||||
Request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Default Truncate",
|
||||
Request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Explicit Truncate",
|
||||
Request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
res := make(map[string]*api.EmbedResponse)
|
||||
|
||||
for _, req := range reqs {
|
||||
response, err := embedTestHelper(ctx, t, req.Request)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
res[req.Name] = response
|
||||
}
|
||||
|
||||
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
||||
t.Fatal("expected default request to truncate correctly")
|
||||
}
|
||||
|
||||
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
||||
t.Fatal("expected default request and truncate true request to be the same")
|
||||
}
|
||||
|
||||
// check that truncate set to false returns an error if context length is exceeded
|
||||
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
|
||||
response, err := client.Embed(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
Reference in New Issue
Block a user