From 53e9576f46d19bb53ccba7841d5611137d690d14 Mon Sep 17 00:00:00 2001 From: Roy Han Date: Thu, 11 Jul 2024 20:20:14 -0700 Subject: [PATCH] testing clean up --- integration/embed_test.go | 11 ++++++----- server/routes_test.go | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/integration/embed_test.go b/integration/embed_test.go index 87fdeb8f9..aeafa57b6 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/ollama/ollama/api" - "github.com/stretchr/testify/require" ) func TestAllMiniLMEmbed(t *testing.T) { @@ -116,11 +115,11 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) { } if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] { - t.Fatalf("expected default request to truncate correctly") + t.Fatal("expected default request to truncate correctly") } if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] { - t.Fatalf("expected default request and truncate true request to be the same") + t.Fatal("expected default request and truncate true request to be the same") } // check that truncate set to false returns an error if context length is exceeded @@ -132,14 +131,16 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) { }) if err == nil { - t.Fatalf("expected error, got nil") + t.Fatal("expected error, got nil") } } func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - require.NoError(t, PullIfMissing(ctx, client, req.Model)) + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatalf("failed to pull model %s: %v", req.Model, err) + } response, err := client.Embed(ctx, &req) diff --git a/server/routes_test.go b/server/routes_test.go index 8666311cf..02c4ecd11 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -483,7 +483,7 @@ func TestNormalize(t *testing.T) { {input: []float32{0, 0, 0}}, } - assertNorm := func(vec []float32) (res bool) { + isNormalized := func(vec []float32) (res bool) { sum := 0.0 for _, v := range vec { sum += float64(v * v) @@ -498,7 +498,7 @@ func TestNormalize(t *testing.T) { for _, tc := range testCases { t.Run("", func(t *testing.T) { normalized := normalize(tc.input) - if !assertNorm(normalized) { + if !isNormalized(normalized) { t.Errorf("Vector %v is not normalized", tc.input) } })