From 7bec2724a56ad990d66b9ca05e2b47191955596c Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 29 Apr 2025 11:57:54 -0700 Subject: [PATCH] integration: fix embedding tests error handling (#10478) The cleanup routine from InitServerconnection should run in the defer of the test case to properly detect failures and report the server logs --- integration/embed_test.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/integration/embed_test.go b/integration/embed_test.go index 8a95816a5..09369dbb4 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V { func TestAllMiniLMEmbeddings(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() req := api.EmbeddingRequest{ Model: "all-minilm", Prompt: "why is the sky blue?", } - res, err := embeddingTestHelper(ctx, t, req) + res, err := embeddingTestHelper(ctx, client, t, req) if err != nil { t.Fatalf("error: %v", err) @@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) { func TestAllMiniLMEmbed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() req := api.EmbedRequest{ Model: "all-minilm", Input: "why is the sky blue?", } - res, err := embedTestHelper(ctx, t, req) + res, err := embedTestHelper(ctx, client, t, req) if err != nil { t.Fatalf("error: %v", err) @@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) { func TestAllMiniLMBatchEmbed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() req := api.EmbedRequest{ Model: "all-minilm", Input: []string{"why is the sky blue?", "why is the grass green?"}, } - res, err := embedTestHelper(ctx, t, req) + res, err := embedTestHelper(ctx, client, t, req) if err != nil { t.Fatalf("error: %v", err) @@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { func TestAllMiniLMEmbedTruncate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() truncTrue, truncFalse := true, false @@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { res := make(map[string]*api.EmbedResponse) for _, req := range reqs { - response, err := embedTestHelper(ctx, t, req.Request) + response, err := embedTestHelper(ctx, client, t, req.Request) if err != nil { t.Fatalf("error: %v", err) } @@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { } // check that truncate set to false returns an error if context length is exceeded - _, err := embedTestHelper(ctx, t, api.EmbedRequest{ + _, err := embedTestHelper(ctx, client, t, api.EmbedRequest{ Model: "all-minilm", Input: "why is the sky blue?", Truncate: &truncFalse, @@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { } } -func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { - client, _, cleanup := InitServerConnection(ctx, t) - defer cleanup() +func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("failed to pull model %s: %v", req.Model, err) } @@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq return response, nil } -func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { - client, _, cleanup := InitServerConnection(ctx, t) - defer cleanup() +func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("failed to pull model %s: %v", req.Model, err) }