This commit is contained in:
Roy Han 2024-07-09 13:37:00 -07:00
parent 3342e5f035
commit bcb63e6e0e
4 changed files with 24 additions and 29 deletions

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
) )
func TestAllMiniLMEmbed(t *testing.T) { func TestAllMiniLMEmbed(t *testing.T) {
@ -19,7 +20,7 @@ func TestAllMiniLMEmbed(t *testing.T) {
Input: "why is the sky blue?", Input: "why is the sky blue?",
} }
res, err := EmbedTestHelper(ctx, t, req) res, err := embedTestHelper(ctx, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@ -47,7 +48,7 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
Input: []string{"why is the sky blue?", "why is the grass green?"}, Input: []string{"why is the sky blue?", "why is the grass green?"},
} }
res, err := EmbedTestHelper(ctx, t, req) res, err := embedTestHelper(ctx, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@ -107,7 +108,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
res := make(map[string]*api.EmbedResponse) res := make(map[string]*api.EmbedResponse)
for _, req := range reqs { for _, req := range reqs {
response, err := EmbedTestHelper(ctx, t, req.Request) response, err := embedTestHelper(ctx, t, req.Request)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
@ -123,7 +124,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
} }
// check that truncate set to false returns an error if context length is exceeded // check that truncate set to false returns an error if context length is exceeded
_, err := EmbedTestHelper(ctx, t, api.EmbedRequest{ _, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why is the sky blue?", Input: "why is the sky blue?",
Truncate: &truncFalse, Truncate: &truncFalse,
@ -134,3 +135,17 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
t.Fatalf("expected error, got nil") t.Fatalf("expected error, got nil")
} }
} }
// embedTestHelper sends req to the server's embed endpoint and returns the
// response.
//
// It obtains a client from InitServerConnection and defers the cleanup
// function that call returns. It then calls PullIfMissing for req.Model and
// fails the test immediately, via require.NoError, if that call returns an
// error. An error from client.Embed is handed back to the caller instead of
// failing the test, so tests can assert on expected failures.
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
	// Mark this function as a helper so assertion failures are reported at
	// the calling test's line rather than inside this function.
	t.Helper()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	require.NoError(t, PullIfMissing(ctx, client, req.Model))

	response, err := client.Embed(ctx, &req)
	if err != nil {
		// Return a nil response on error, whatever Embed produced.
		return nil, err
	}

	return response, nil
}

View File

@ -341,17 +341,3 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[]string{"nitrogen", "oxygen", "carbon", "dioxide"}, []string{"nitrogen", "oxygen", "carbon", "dioxide"},
} }
} }
// EmbedTestHelper sends req to the server's embed endpoint and returns the
// response.
//
// It obtains a client from InitServerConnection and defers the cleanup
// function that call returns. It then calls PullIfMissing for req.Model and
// fails the test immediately, via require.NoError, if that call returns an
// error. An error from client.Embed is handed back to the caller instead of
// failing the test, so tests can assert on expected failures.
func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
	// Mark this function as a helper so assertion failures are reported at
	// the calling test's line rather than inside this function.
	t.Helper()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	require.NoError(t, PullIfMissing(ctx, client, req.Model))

	response, err := client.Embed(ctx, &req)
	if err != nil {
		// Return a nil response on error, whatever Embed produced.
		return nil, err
	}

	return response, nil
}

View File

@ -3199,13 +3199,7 @@ int main(int argc, char **argv) {
task_result result = llama.queue_results.recv(id_task); task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(id_task); llama.queue_results.remove_waiting_task_id(id_task);
if (!result.error) { if (!result.error) {
if (result.result_json.count("results")) { responses = result.result_json.value("results", std::vector<json>{result.result_json});
// result for multi-task
responses = result.result_json.at("results");
} else {
// result for single task
responses = std::vector<json>(1, result.result_json);
}
json embeddings = json::array(); json embeddings = json::array();
for (auto & elem : responses) { for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding")); embeddings.push_back(elem.at("embedding"));

View File

@ -9,7 +9,6 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@ -21,6 +20,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/chewxy/math32"
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -443,14 +443,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
func normalize(vec []float32) []float32 { func normalize(vec []float32) []float32 {
var sum float64 var sum float32
for _, v := range vec { for _, v := range vec {
sum += float64(v * v) sum += v * v
} }
norm := float32(0.0) norm := float32(0.0)
if sum > 0 { if sum > 0 {
norm = float32(1.0 / math.Sqrt(sum)) norm = float32(1.0 / math32.Sqrt(sum))
} }
for i := range vec { for i := range vec {