diff --git a/integration/embed_test.go b/integration/embed_test.go index f13fd388a..8f3adb5b6 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/ollama/ollama/api" + "github.com/stretchr/testify/require" ) func TestAllMiniLMEmbed(t *testing.T) { @@ -19,7 +20,7 @@ func TestAllMiniLMEmbed(t *testing.T) { Input: "why is the sky blue?", } - res, err := EmbedTestHelper(ctx, t, req) + res, err := embedTestHelper(ctx, t, req) if err != nil { t.Fatalf("error: %v", err) @@ -47,7 +48,7 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { Input: []string{"why is the sky blue?", "why is the grass green?"}, } - res, err := EmbedTestHelper(ctx, t, req) + res, err := embedTestHelper(ctx, t, req) if err != nil { t.Fatalf("error: %v", err) @@ -107,7 +108,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) { res := make(map[string]*api.EmbedResponse) for _, req := range reqs { - response, err := EmbedTestHelper(ctx, t, req.Request) + response, err := embedTestHelper(ctx, t, req.Request) if err != nil { t.Fatalf("error: %v", err) } @@ -123,7 +124,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) { } // check that truncate set to false returns an error if context length is exceeded - _, err := EmbedTestHelper(ctx, t, api.EmbedRequest{ + _, err := embedTestHelper(ctx, t, api.EmbedRequest{ Model: "all-minilm", Input: "why is the sky blue?", Truncate: &truncFalse, @@ -134,3 +135,17 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) { t.Fatalf("expected error, got nil") } } + +func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + require.NoError(t, PullIfMissing(ctx, client, req.Model)) + + response, err := client.Embed(ctx, &req) + + if err != nil { + return nil, err + } + + return response, nil +} diff --git a/integration/utils_test.go b/integration/utils_test.go index 552500201..7e1fcc10e 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -341,17 +341,3 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { []string{"nitrogen", "oxygen", "carbon", "dioxide"}, } } - -func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { - client, _, cleanup := InitServerConnection(ctx, t) - defer cleanup() - require.NoError(t, PullIfMissing(ctx, client, req.Model)) - - response, err := client.Embed(ctx, &req) - - if err != nil { - return nil, err - } - - return response, nil -} diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index 169627874..9b0dc9d83 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -3199,13 +3199,7 @@ int main(int argc, char **argv) { task_result result = llama.queue_results.recv(id_task); llama.queue_results.remove_waiting_task_id(id_task); if (!result.error) { - if (result.result_json.count("results")) { - // result for multi-task - responses = result.result_json.at("results"); - } else { - // result for single task - responses = std::vector(1, result.result_json); - } + responses = result.result_json.value("results", std::vector{result.result_json}); json embeddings = json::array(); for (auto & elem : responses) { embeddings.push_back(elem.at("embedding")); diff --git a/server/routes.go b/server/routes.go index 0e57251d4..593eeee20 100644 --- a/server/routes.go +++ b/server/routes.go @@ -9,7 +9,6 @@ import ( "io" "io/fs" "log/slog" - "math" "net" "net/http" "net/netip" @@ -21,6 +20,7 @@ import ( "syscall" "time" + "github.com/chewxy/math32" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" @@ -443,14 +443,14 @@ func (s *Server) EmbedHandler(c *gin.Context) { } func normalize(vec []float32) []float32 { - var sum float64 + var sum float32 for _, v := range vec { - sum += float64(v * v) + sum += v * v } norm := float32(0.0) if sum > 0 { - norm = float32(1.0 / math.Sqrt(sum)) + norm = float32(1.0 / math32.Sqrt(sum)) } for i := range vec {