From c22d54895a280b54c727279d85a5fc94defb5a29 Mon Sep 17 00:00:00 2001
From: Roy Han
Date: Tue, 18 Jun 2024 17:34:36 -0700
Subject: [PATCH] Initial Batch Embedding

---
 api/types.go         |  5 +++--
 llm/server.go        | 51 ++++++++++++++++++++++++++++++++++++++++++--
 server/routes.go     | 42 ++++++++++++++++++++++++++++++++----
 server/sched_test.go |  5 +++--
 4 files changed, 93 insertions(+), 10 deletions(-)

diff --git a/api/types.go b/api/types.go
index 7822a6034..0d6a23f2e 100644
--- a/api/types.go
+++ b/api/types.go
@@ -210,7 +210,8 @@ type EmbeddingRequest struct {
 	Model string `json:"model"`
 
 	// Prompt is the textual prompt to embed.
-	Prompt string `json:"prompt"`
+	// Prompt string `json:"prompt"`
+	Prompt interface{} `json:"prompt"`
 
 	// KeepAlive controls how long the model will stay loaded in memory following
 	// this request.
@@ -222,7 +223,7 @@ type EmbeddingRequest struct {
 
 // EmbeddingResponse is the response from [Client.Embeddings].
 type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+	Embedding [][]float64 `json:"embedding"`
 }
 
 // CreateRequest is the request passed to [Client.Create].
diff --git a/llm/server.go b/llm/server.go
index 7de1ec1b1..da4cc3bf6 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -19,6 +19,7 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"golang.org/x/sync/semaphore"
@@ -33,7 +34,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Embedding(ctx context.Context, prompts interface{}) ([][]float64, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -849,7 +850,7 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
-func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embedding(ctx context.Context, prompts interface{}) ([][]float64, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -864,6 +865,52 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
+	switch prompts := prompts.(type) {
+	case string:
+		// single prompt
+		embedding, err := s.EmbeddingSingle(ctx, prompts)
+		if err != nil {
+			return nil, err
+		}
+		return [][]float64{embedding}, nil
+	case []string:
+		// multiple prompts
+		errCh := make(chan error, 1)
+		successCh := make(chan [][]float64, 1)
+		num_prompts := len(prompts)
+		embeddings := make([][]float64, num_prompts)
+		var wg sync.WaitGroup
+		wg.Add(num_prompts)
+		for i, p := range prompts {
+			go func(i int, p string) {
+				defer wg.Done()
+				slog.Info("embedding", "prompt", p)
+				embedding, err := s.EmbeddingSingle(ctx, p)
+				if err != nil {
+					errCh <- err
+					return
+				}
+				embeddings[i] = embedding
+			}(i, p)
+		}
+
+		go func() {
+			wg.Wait()
+			successCh <- embeddings
+		}()
+
+		select {
+		case err := <-errCh:
+			return nil, err
+		case embeddings := <-successCh:
+			return embeddings, nil
+		}
+	default:
+		return nil, fmt.Errorf("unsupported prompt type: %T", prompts)
+	}
+}
+
+func (s *llmServer) EmbeddingSingle(ctx context.Context, prompt string) ([]float64, error) {
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
diff --git a/server/routes.go b/server/routes.go
index f36fe1b08..794cc840d 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -356,6 +356,27 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
+	// if we want to stick with the one prompt format, we can use custom unmarshalling
+	// otherwise just have separate fields
+
+	switch req.Prompt.(type) {
+	case string:
+	case []interface{}:
+		prompts := make([]string, len(req.Prompt.([]interface{})))
+		for i, p := range req.Prompt.([]interface{}) {
+			if str, ok := p.(string); ok {
+				prompts[i] = str
+			} else {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "prompt must be a string or list of strings"})
+				return
+			}
+		}
+		req.Prompt = prompts
+	default:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "prompt must be a string or list of strings"})
+		return
+	}
+
 	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
@@ -389,13 +410,26 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	// an empty request loads the model
-	if req.Prompt == "" {
-		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
+	var embedding [][]float64
+
+	switch prompt := req.Prompt.(type) {
+	case string:
+		if prompt == "" {
+			c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: [][]float64{}})
+			return
+		}
+		embedding, err = runner.llama.Embedding(c.Request.Context(), prompt)
+	case []string:
+		if len(prompt) == 0 {
+			c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: [][]float64{}})
+			return
+		}
+		embedding, err = runner.llama.Embedding(c.Request.Context(), prompt)
+	default:
+		c.AbortWithStatus(http.StatusInternalServerError)
 		return
 	}
 
-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
diff --git a/server/sched_test.go b/server/sched_test.go
index 953288347..6c0cbfee1 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -608,7 +608,7 @@ type mockLlm struct {
 	pingResp         error
 	waitResp         error
 	completionResp   error
-	embeddingResp    []float64
+	embeddingResp    [][]float64
 	embeddingRespErr error
 	tokenizeResp     []int
 	tokenizeRespErr  error
@@ -626,7 +626,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 }
-func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+
+func (s *mockLlm) Embedding(ctx context.Context, prompts interface{}) ([][]float64, error) {
 	return s.embeddingResp, s.embeddingRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
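
Usage sketch (not part of the patch): after this change, EmbeddingRequest.Prompt accepts either a single string or a list of strings, and EmbeddingResponse carries one vector per input prompt. The standalone client below is a minimal sketch of the batched path; the route (POST /api/embeddings), the default address (localhost:11434), and the model name ("all-minilm") are illustrative assumptions, not taken from the diff.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// embeddingRequest mirrors api.EmbeddingRequest after this patch: Prompt is
// interface{} so the JSON field may hold a single string or a list of strings.
type embeddingRequest struct {
	Model  string      `json:"model"`
	Prompt interface{} `json:"prompt"`
}

// embeddingResponse mirrors api.EmbeddingResponse after this patch:
// one embedding vector per input prompt, in input order.
type embeddingResponse struct {
	Embedding [][]float64 `json:"embedding"`
}

func main() {
	body, err := json.Marshal(embeddingRequest{
		Model:  "all-minilm", // assumption: any locally available embedding model
		Prompt: []string{"why is the sky blue?", "why is grass green?"},
	})
	if err != nil {
		log.Fatal(err)
	}

	// Assumption: EmbeddingsHandler is served at POST /api/embeddings on the
	// default server address.
	resp, err := http.Post("http://localhost:11434/api/embeddings", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out embeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("received %d embeddings\n", len(out.Embedding)) // expect one per prompt
}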