touches
This commit is contained in:
parent
3342e5f035
commit
bcb63e6e0e
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAllMiniLMEmbed(t *testing.T) {
|
||||
@ -19,7 +20,7 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||
Input: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := EmbedTestHelper(ctx, t, req)
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@ -47,7 +48,7 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
||||
}
|
||||
|
||||
res, err := EmbedTestHelper(ctx, t, req)
|
||||
res, err := embedTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
@ -107,7 +108,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
||||
res := make(map[string]*api.EmbedResponse)
|
||||
|
||||
for _, req := range reqs {
|
||||
response, err := EmbedTestHelper(ctx, t, req.Request)
|
||||
response, err := embedTestHelper(ctx, t, req.Request)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
@ -123,7 +124,7 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
||||
}
|
||||
|
||||
// check that truncate set to false returns an error if context length is exceeded
|
||||
_, err := EmbedTestHelper(ctx, t, api.EmbedRequest{
|
||||
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
@ -134,3 +135,17 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
|
||||
response, err := client.Embed(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
@ -341,17 +341,3 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
}
|
||||
|
||||
func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
|
||||
response, err := client.Embed(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
8
llm/ext_server/server.cpp
vendored
8
llm/ext_server/server.cpp
vendored
@ -3199,13 +3199,7 @@ int main(int argc, char **argv) {
|
||||
task_result result = llama.queue_results.recv(id_task);
|
||||
llama.queue_results.remove_waiting_task_id(id_task);
|
||||
if (!result.error) {
|
||||
if (result.result_json.count("results")) {
|
||||
// result for multi-task
|
||||
responses = result.result_json.at("results");
|
||||
} else {
|
||||
// result for single task
|
||||
responses = std::vector<json>(1, result.result_json);
|
||||
}
|
||||
responses = result.result_json.value("results", std::vector<json>{result.result_json});
|
||||
json embeddings = json::array();
|
||||
for (auto & elem : responses) {
|
||||
embeddings.push_back(elem.at("embedding"));
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@ -21,6 +20,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/chewxy/math32"
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@ -443,14 +443,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func normalize(vec []float32) []float32 {
|
||||
var sum float64
|
||||
var sum float32
|
||||
for _, v := range vec {
|
||||
sum += float64(v * v)
|
||||
sum += v * v
|
||||
}
|
||||
|
||||
norm := float32(0.0)
|
||||
if sum > 0 {
|
||||
norm = float32(1.0 / math.Sqrt(sum))
|
||||
norm = float32(1.0 / math32.Sqrt(sum))
|
||||
}
|
||||
|
||||
for i := range vec {
|
||||
|
Loading…
x
Reference in New Issue
Block a user