diff --git a/integration/embed_test.go b/integration/embed_test.go index d1efda2f5..01baa265c 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -19,7 +19,11 @@ func TestAllMiniLMEmbed(t *testing.T) { Input: "why is the sky blue?", } - res := EmbedTestHelper(ctx, t, req) + res, err := EmbedTestHelper(ctx, t, req) + + if err != nil { + t.Fatalf("error: %v", err) + } if len(res.Embeddings) != 1 { t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings)) @@ -28,6 +32,10 @@ func TestAllMiniLMEmbed(t *testing.T) { if len(res.Embeddings[0]) != 384 { t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) } + + if res.Embeddings[0][0] != 0.010071029038540258 { + t.Fatalf("expected 0.010071029038540258, got %f", res.Embeddings[0][0]) + } } func TestAllMiniLMBatchEmbed(t *testing.T) { @@ -39,7 +47,11 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { Input: []string{"why is the sky blue?", "why is the grass green?"}, } - res := EmbedTestHelper(ctx, t, req) + res, err := EmbedTestHelper(ctx, t, req) + + if err != nil { + t.Fatalf("error: %v", err) + } if len(res.Embeddings) != 2 { t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings)) @@ -48,4 +60,77 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { if len(res.Embeddings[0]) != 384 { t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) } + + if res.Embeddings[0][0] != 0.010071029038540258 || res.Embeddings[1][0] != -0.00980270794235093 { + t.Fatalf("expected 0.010071029038540258 and -0.00980270794235093, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0]) + } +} + +func TestAllMiniLmEmbedTruncate(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + truncTrue, truncFalse := true, false + + type testReq struct { + Name string + Request api.EmbedRequest + } + + reqs := []testReq{ + { + Name: "Target Truncation", + Request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why", + }, + }, + { + Name: "Default Truncate", + Request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Options: map[string]any{"num_ctx": 1}, + }, + }, + { + Name: "Explicit Truncate", + Request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 1}, + }, + }, + } + + res := make(map[string]*api.EmbedResponse) + + for _, req := range reqs { + response, err := EmbedTestHelper(ctx, t, req.Request) + if err != nil { + t.Fatalf("error: %v", err) + } + res[req.Name] = response + } + + if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] { + t.Fatalf("expected default request to truncate correctly") + } + + if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] { + t.Fatalf("expected default request and truncate true request to be the same") + } + + // check that truncate set to false returns an error if context length is exceeded + _, err := EmbedTestHelper(ctx, t, api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 1}, + }) + + if err == nil { + t.Fatalf("expected error, got nil") + } } diff --git a/integration/utils_test.go b/integration/utils_test.go index c561e502a..552500201 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -342,7 +342,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { } } -func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) *api.EmbedResponse { +func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() require.NoError(t, PullIfMissing(ctx, client, req.Model)) @@ -350,8 +350,8 @@ func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) *a response, err := client.Embed(ctx, &req) if err != nil { - t.Fatalf("Error making request: %v", err) + return nil, err } - return response + return response, nil } diff --git a/server/routes.go b/server/routes.go index ba2ac9eeb..90aeb3ba0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -395,7 +395,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - truncate := func(s string) (string, error) { + checkFit := func(s string, truncate bool) (string, error) { tokens, err := runner.llama.Tokenize(c.Request.Context(), s) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -403,8 +403,12 @@ func (s *Server) EmbedHandler(c *gin.Context) { } if len(tokens) > opts.NumCtx { - tokens = tokens[:opts.NumCtx] - return runner.llama.Detokenize(c.Request.Context(), tokens) + if truncate { + tokens = tokens[:opts.NumCtx] + return runner.llama.Detokenize(c.Request.Context(), tokens) + } else { + return "", fmt.Errorf("input length exceeds maximum context length") + } } return s, nil @@ -418,12 +422,10 @@ func (s *Server) EmbedHandler(c *gin.Context) { c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}}) return } - if *req.Truncate { - reqEmbed, err = truncate(reqEmbed) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + reqEmbed, err = checkFit(reqEmbed, *req.Truncate) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed}) case []any: @@ -435,12 +437,10 @@ func (s *Server) EmbedHandler(c *gin.Context) { reqEmbedArray := make([]string, len(reqEmbed)) for i, v := range reqEmbed { if s, ok := v.(string); ok { - if *req.Truncate { - s, err = truncate(s) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + s, err = checkFit(s, *req.Truncate) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } reqEmbedArray[i] = s } else {