Truncation Integration Tests
This commit is contained in:
parent
e068e7f698
commit
1a0c8b363c
@ -19,7 +19,11 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|||||||
Input: "why is the sky blue?",
|
Input: "why is the sky blue?",
|
||||||
}
|
}
|
||||||
|
|
||||||
res := EmbedTestHelper(ctx, t, req)
|
res, err := EmbedTestHelper(ctx, t, req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if len(res.Embeddings) != 1 {
|
if len(res.Embeddings) != 1 {
|
||||||
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
|
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
|
||||||
@ -28,6 +32,10 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|||||||
if len(res.Embeddings[0]) != 384 {
|
if len(res.Embeddings[0]) != 384 {
|
||||||
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Embeddings[0][0] != 0.010071029038540258 {
|
||||||
|
t.Fatalf("expected 0.010071029038540258, got %f", res.Embeddings[0][0])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||||
@ -39,7 +47,11 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|||||||
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
||||||
}
|
}
|
||||||
|
|
||||||
res := EmbedTestHelper(ctx, t, req)
|
res, err := EmbedTestHelper(ctx, t, req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if len(res.Embeddings) != 2 {
|
if len(res.Embeddings) != 2 {
|
||||||
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
|
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
|
||||||
@ -48,4 +60,77 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|||||||
if len(res.Embeddings[0]) != 384 {
|
if len(res.Embeddings[0]) != 384 {
|
||||||
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Embeddings[0][0] != 0.010071029038540258 || res.Embeddings[1][0] != -0.00980270794235093 {
|
||||||
|
t.Fatalf("expected 0.010071029038540258 and -0.00980270794235093, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
truncTrue, truncFalse := true, false
|
||||||
|
|
||||||
|
type testReq struct {
|
||||||
|
Name string
|
||||||
|
Request api.EmbedRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
reqs := []testReq{
|
||||||
|
{
|
||||||
|
Name: "Target Truncation",
|
||||||
|
Request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Default Truncate",
|
||||||
|
Request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Options: map[string]any{"num_ctx": 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Explicit Truncate",
|
||||||
|
Request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncTrue,
|
||||||
|
Options: map[string]any{"num_ctx": 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
res := make(map[string]*api.EmbedResponse)
|
||||||
|
|
||||||
|
for _, req := range reqs {
|
||||||
|
response, err := EmbedTestHelper(ctx, t, req.Request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
res[req.Name] = response
|
||||||
|
}
|
||||||
|
|
||||||
|
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
||||||
|
t.Fatalf("expected default request to truncate correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
||||||
|
t.Fatalf("expected default request and truncate true request to be the same")
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that truncate set to false returns an error if context length is exceeded
|
||||||
|
_, err := EmbedTestHelper(ctx, t, api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncFalse,
|
||||||
|
Options: map[string]any{"num_ctx": 1},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error, got nil")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -342,7 +342,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) *api.EmbedResponse {
|
func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||||
@ -350,8 +350,8 @@ func EmbedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) *a
|
|||||||
response, err := client.Embed(ctx, &req)
|
response, err := client.Embed(ctx, &req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error making request: %v", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return response
|
return response, nil
|
||||||
}
|
}
|
||||||
|
@ -395,7 +395,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
truncate := func(s string) (string, error) {
|
checkFit := func(s string, truncate bool) (string, error) {
|
||||||
tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
|
tokens, err := runner.llama.Tokenize(c.Request.Context(), s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@ -403,8 +403,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(tokens) > opts.NumCtx {
|
if len(tokens) > opts.NumCtx {
|
||||||
|
if truncate {
|
||||||
tokens = tokens[:opts.NumCtx]
|
tokens = tokens[:opts.NumCtx]
|
||||||
return runner.llama.Detokenize(c.Request.Context(), tokens)
|
return runner.llama.Detokenize(c.Request.Context(), tokens)
|
||||||
|
} else {
|
||||||
|
return "", fmt.Errorf("input length exceeds maximum context length")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
@ -418,13 +422,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if *req.Truncate {
|
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
||||||
reqEmbed, err = truncate(reqEmbed)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
||||||
case []any:
|
case []any:
|
||||||
if reqEmbed == nil {
|
if reqEmbed == nil {
|
||||||
@ -435,13 +437,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
reqEmbedArray := make([]string, len(reqEmbed))
|
reqEmbedArray := make([]string, len(reqEmbed))
|
||||||
for i, v := range reqEmbed {
|
for i, v := range reqEmbed {
|
||||||
if s, ok := v.(string); ok {
|
if s, ok := v.(string); ok {
|
||||||
if *req.Truncate {
|
s, err = checkFit(s, *req.Truncate)
|
||||||
s, err = truncate(s)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
reqEmbedArray[i] = s
|
reqEmbedArray[i] = s
|
||||||
} else {
|
} else {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user