diff --git a/integration/embed_test.go b/integration/embed_test.go new file mode 100644 index 000000000..d1efda2f5 --- /dev/null +++ b/integration/embed_test.go @@ -0,0 +1,51 @@ +//go:build integration + +package integration + +import ( + "context" + "testing" + "time" + + "github.com/ollama/ollama/api" +) + +func TestAllMiniLMEmbed(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + } + + res := EmbedTestHelper(ctx, t, req) + + if len(res.Embeddings) != 1 { + t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings)) + } + + if len(res.Embeddings[0]) != 384 { + t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) + } +} + +func TestAllMiniLMBatchEmbed(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + req := api.EmbedRequest{ + Model: "all-minilm", + Input: []string{"why is the sky blue?", "why is the grass green?"}, + } + + res := EmbedTestHelper(ctx, t, req) + + if len(res.Embeddings) != 2 { + t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings)) + } + + if len(res.Embeddings[0]) != 384 { + t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) + } +} diff --git a/integration/utils_test.go b/integration/utils_test.go index 7e1fcc10e..c561e502a 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -341,3 +341,17 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { []string{"nitrogen", "oxygen", "carbon", "dioxide"}, } } + +func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) *api.EmbedResponse { + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + require.NoError(t, PullIfMissing(ctx, client, req.Model)) + + response, err := client.Embed(ctx, &req) + + if err != nil { + t.Fatalf("Error making request: %v", err) + } + + return response +}