diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp
index 001194764..7f0e67519 100644
--- a/llm/ext_server/server.cpp
+++ b/llm/ext_server/server.cpp
@@ -3165,14 +3165,6 @@ int main(int argc, char **argv) {
             {
                 input = "";
             }
-            if (body.count("input") != 0)
-            {
-                input = body["input"];
-            }
-            else
-            {
-                input = "";
-            }
 
             // create and queue the task
             json responses;
diff --git a/llm/server.go b/llm/server.go
index 7de1ec1b1..ddd5b66bc 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -34,6 +34,7 @@ type LlamaServer interface {
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
 	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Embed(ctx context.Context, input []string) ([][]float64, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -841,6 +842,64 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	return nil
 }
 
+// EmbedRequest is the JSON body sent to the runner's /embedding endpoint.
+type EmbedRequest struct {
+	Content []string `json:"content"`
+}
+
+// EmbedResponse is the JSON body returned by the runner's /embedding endpoint.
+type EmbedResponse struct {
+	Embedding [][]float64 `json:"embedding"`
+}
+
+// Embed returns one embedding vector per input string, in input order.
+func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, error) {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return nil, err
+	}
+	defer s.sem.Release(1)
+
+	// Make sure the server is ready
+	status, err := s.getServerStatusRetry(ctx)
+	if err != nil {
+		return nil, err
+	} else if status != ServerStatusReady {
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+	}
+
+	data, err := json.Marshal(EmbedRequest{Content: input})
+	if err != nil {
+		return nil, fmt.Errorf("error marshaling embed data: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
+	if err != nil {
+		return nil, fmt.Errorf("error creating embed request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("do embedding request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("error reading embed response: %w", err)
+	}
+
+	if resp.StatusCode >= 400 {
+		log.Printf("llm embed error: %s", body)
+		return nil, fmt.Errorf("%s", body)
+	}
+
+	var embedding EmbedResponse
+	if err := json.Unmarshal(body, &embedding); err != nil {
+		return nil, fmt.Errorf("unmarshal embed response: %w", err)
+	}
+
+	return embedding.Embedding, nil
+}
+
 type EmbeddingRequest struct {
 	Content string `json:"content"`
 }
diff --git a/server/routes.go b/server/routes.go
index 264bf596a..1db949f0b 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -397,13 +397,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
 			return
 		}
-		embeddings, err = runner.llama.Embedding(c.Request.Context(), []string{reqEmbed})
+		embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
 	case []string:
 		if reqEmbed == nil {
 			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
 			return
 		}
-		embeddings, err = runner.llama.Embedding(c.Request.Context(), reqEmbed)
+		embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbed)
 	default:
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 	}
diff --git a/server/sched_test.go b/server/sched_test.go
index 953288347..1038b68dc 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -610,6 +610,8 @@ type mockLlm struct {
 	completionResp     error
 	embeddingResp      []float64
 	embeddingRespErr   error
+	embedResp          [][]float64
+	embedRespErr       error
 	tokenizeResp       []int
 	tokenizeRespErr    error
 	detokenizeResp     string
@@ -629,6 +631,9 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
 func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
 	return s.embeddingResp, s.embeddingRespErr
 }
+func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float64, error) {
+	return s.embedResp, s.embedRespErr
+}
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
 	return s.tokenizeResp, s.tokenizeRespErr
 }