testing clean up
This commit is contained in:
parent
dbe9527305
commit
53e9576f46
@ -8,7 +8,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAllMiniLMEmbed(t *testing.T) {
|
func TestAllMiniLMEmbed(t *testing.T) {
|
||||||
@ -116,11 +115,11 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
||||||
t.Fatalf("expected default request to truncate correctly")
|
t.Fatal("expected default request to truncate correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
||||||
t.Fatalf("expected default request and truncate true request to be the same")
|
t.Fatal("expected default request and truncate true request to be the same")
|
||||||
}
|
}
|
||||||
|
|
||||||
// check that truncate set to false returns an error if context length is exceeded
|
// check that truncate set to false returns an error if context length is exceeded
|
||||||
@ -132,14 +131,16 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||||
|
}
|
||||||
|
|
||||||
response, err := client.Embed(ctx, &req)
|
response, err := client.Embed(ctx, &req)
|
||||||
|
|
||||||
|
@ -483,7 +483,7 @@ func TestNormalize(t *testing.T) {
|
|||||||
{input: []float32{0, 0, 0}},
|
{input: []float32{0, 0, 0}},
|
||||||
}
|
}
|
||||||
|
|
||||||
assertNorm := func(vec []float32) (res bool) {
|
isNormalized := func(vec []float32) (res bool) {
|
||||||
sum := 0.0
|
sum := 0.0
|
||||||
for _, v := range vec {
|
for _, v := range vec {
|
||||||
sum += float64(v * v)
|
sum += float64(v * v)
|
||||||
@ -498,7 +498,7 @@ func TestNormalize(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run("", func(t *testing.T) {
|
t.Run("", func(t *testing.T) {
|
||||||
normalized := normalize(tc.input)
|
normalized := normalize(tc.input)
|
||||||
if !assertNorm(normalized) {
|
if !isNormalized(normalized) {
|
||||||
t.Errorf("Vector %v is not normalized", tc.input)
|
t.Errorf("Vector %v is not normalized", tc.input)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user