testing clean up

This commit is contained in:
Roy Han 2024-07-11 20:20:14 -07:00
parent dbe9527305
commit 53e9576f46
2 changed files with 8 additions and 7 deletions

View File

@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
) )
func TestAllMiniLMEmbed(t *testing.T) { func TestAllMiniLMEmbed(t *testing.T) {
@ -116,11 +115,11 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
} }
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] { if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatalf("expected default request to truncate correctly") t.Fatal("expected default request to truncate correctly")
} }
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] { if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatalf("expected default request and truncate true request to be the same") t.Fatal("expected default request and truncate true request to be the same")
} }
// check that truncate set to false returns an error if context length is exceeded // check that truncate set to false returns an error if context length is exceeded
@ -132,14 +131,16 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
}) })
if err == nil { if err == nil {
t.Fatalf("expected error, got nil") t.Fatal("expected error, got nil")
} }
} }
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model)) if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
response, err := client.Embed(ctx, &req) response, err := client.Embed(ctx, &req)

View File

@ -483,7 +483,7 @@ func TestNormalize(t *testing.T) {
{input: []float32{0, 0, 0}}, {input: []float32{0, 0, 0}},
} }
assertNorm := func(vec []float32) (res bool) { isNormalized := func(vec []float32) (res bool) {
sum := 0.0 sum := 0.0
for _, v := range vec { for _, v := range vec {
sum += float64(v * v) sum += float64(v * v)
@ -498,7 +498,7 @@ func TestNormalize(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
normalized := normalize(tc.input) normalized := normalize(tc.input)
if !assertNorm(normalized) { if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input) t.Errorf("Vector %v is not normalized", tc.input)
} }
}) })