This commit is contained in:
Roy Han 2024-07-09 13:37:00 -07:00
parent 3342e5f035
commit bcb63e6e0e
4 changed files with 24 additions and 29 deletions

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestAllMiniLMEmbed(t *testing.T) {
@ -19,7 +20,7 @@ func TestAllMiniLMEmbed(t *testing.T) {
Input: "why is the sky blue?",
}
res, err := EmbedTestHelper(ctx, t, req)
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@ -47,7 +48,7 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
Input: []string{"why is the sky blue?", "why is the grass green?"},
}
res, err := EmbedTestHelper(ctx, t, req)
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
@ -107,7 +108,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := EmbedTestHelper(ctx, t, req.Request)
response, err := embedTestHelper(ctx, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
@ -123,7 +124,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
}
// check that truncate set to false returns an error if context length is exceeded
_, err := EmbedTestHelper(ctx, t, api.EmbedRequest{
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
@ -134,3 +135,17 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
t.Fatalf("expected error, got nil")
}
}
// embedTestHelper establishes a test server connection, ensures the model
// named in req is available locally, and performs a single Embed call.
// It returns the embed response, or a nil response and the error from the
// Embed call. Connection setup and model-pull failures fail t directly
// via require.
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
	// Set up the client; cleanup tears the connection down when we return.
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// The model must be present before embedding; a pull failure is fatal.
	require.NoError(t, PullIfMissing(ctx, client, req.Model))

	resp, embedErr := client.Embed(ctx, &req)
	if embedErr != nil {
		return nil, embedErr
	}
	return resp, nil
}

View File

@ -341,17 +341,3 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
}
}
// EmbedTestHelper opens a connection to the test server, pulls req.Model if
// it is not already present, and runs one Embed request against it.
// On success the embed response is returned; on an Embed failure the
// response is nil and the error is propagated. Setup errors (connection,
// model pull) abort the test through require rather than being returned.
func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
	// Acquire a client; the returned cleanup closes the connection on exit.
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Guarantee the model exists locally before issuing the embed call.
	require.NoError(t, PullIfMissing(ctx, client, req.Model))

	resp, callErr := client.Embed(ctx, &req)
	if callErr != nil {
		return nil, callErr
	}
	return resp, nil
}

View File

@ -3199,13 +3199,7 @@ int main(int argc, char **argv) {
task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(id_task);
if (!result.error) {
if (result.result_json.count("results")) {
// result for multi-task
responses = result.result_json.at("results");
} else {
// result for single task
responses = std::vector<json>(1, result.result_json);
}
responses = result.result_json.value("results", std::vector<json>{result.result_json});
json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));

View File

@ -9,7 +9,6 @@ import (
"io"
"io/fs"
"log/slog"
"math"
"net"
"net/http"
"net/netip"
@ -21,6 +20,7 @@ import (
"syscall"
"time"
"github.com/chewxy/math32"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
@ -443,14 +443,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
func normalize(vec []float32) []float32 {
var sum float64
var sum float32
for _, v := range vec {
sum += float64(v * v)
sum += v * v
}
norm := float32(0.0)
if sum > 0 {
norm = float32(1.0 / math.Sqrt(sum))
norm = float32(1.0 / math32.Sqrt(sum))
}
for i := range vec {