Integration Test Template
This commit is contained in:
parent
aee25acb5b
commit
e068e7f698
51
integration/embed_test.go
Normal file
51
integration/embed_test.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAllMiniLMEmbed(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
}
|
||||||
|
|
||||||
|
res := EmbedTestHelper(ctx, t, req)
|
||||||
|
|
||||||
|
if len(res.Embeddings) != 1 {
|
||||||
|
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Embeddings[0]) != 384 {
|
||||||
|
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
res := EmbedTestHelper(ctx, t, req)
|
||||||
|
|
||||||
|
if len(res.Embeddings) != 2 {
|
||||||
|
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Embeddings[0]) != 384 {
|
||||||
|
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||||
|
}
|
||||||
|
}
|
@ -341,3 +341,17 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) *api.EmbedResponse {
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||||
|
|
||||||
|
response, err := client.Embed(ctx, &req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error making request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user