From 5213c1235436d99a59b144f87e43cbfc83d94d31 Mon Sep 17 00:00:00 2001 From: Roy Han Date: Fri, 28 Jun 2024 15:26:58 -0700 Subject: [PATCH] clean up --- api/types.go | 8 ++------ llm/ext_server/server.cpp | 8 ++++---- server/routes.go | 1 + 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/api/types.go b/api/types.go index 5fb2e9d39..64f61d24a 100644 --- a/api/types.go +++ b/api/types.go @@ -226,10 +226,7 @@ type EmbeddingRequest struct { Model string `json:"model"` // Prompt is the textual prompt to embed. - Prompt string `json:"prompt,omitempty"` - - // PromptBatch is a list of prompts to embed. - PromptBatch []string `json:"prompt_batch,omitempty"` + Prompt string `json:"prompt"` // KeepAlive controls how long the model will stay loaded in memory following // this request. @@ -246,8 +243,7 @@ type EmbedResponse struct { // EmbeddingResponse is the response from [Client.Embeddings]. type EmbeddingResponse struct { - Embedding []float64 `json:"embedding,omitempty"` - EmbeddingBatch [][]float64 `json:"embedding_batch,omitempty"` + Embedding []float64 `json:"embedding"` } // CreateRequest is the request passed to [Client.Create]. diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index 7f0e67519..cf1fe2eb7 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -3156,14 +3156,14 @@ int main(int argc, char **argv) { { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); - json input; + json prompt; if (body.count("content") != 0) { - input = body["content"]; + prompt = body["content"]; } else { - input = ""; + prompt = ""; } // create and queue the task @@ -3171,7 +3171,7 @@ int main(int argc, char **argv) { { const int id_task = llama.queue_tasks.get_new_id(); llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, {{"prompt", input}}, true, -1); + llama.request_completion(id_task, {{"prompt", prompt}}, true, -1); // get the result task_result result = llama.queue_results.recv(id_task); diff --git a/server/routes.go b/server/routes.go index 7cef7358a..f1735c5f2 100644 --- a/server/routes.go +++ b/server/routes.go @@ -473,6 +473,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}}) return } + embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt) if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err))